173 changes: 172 additions & 1 deletion tests/train/test_trainer.py
@@ -3,6 +3,9 @@
from pathlib import Path
from unittest.mock import patch, Mock
import pickle
import shutil
import weakref
from pydantic import TypeAdapter

import torch
import torch.distributed as dist
@@ -18,11 +21,13 @@
InternVL3P5Dense8BConfig,
InternVL3P5MoE30BA3Config,
)
from xtuner.v1.train.trainer import Trainer, ResumeConfig
from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
from xtuner.v1.datasets import FTDPTokenizeFnConfig
from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
from xtuner.v1.train.trainer import TrainerConfig
from xtuner.v1.engine.train_engine import LossLog, OtherLog
from xtuner.v1.loss import CELossConfig
from xtuner._testing import DeterministicDDPTestCase
from unittest import TestCase

from xtuner.v1.utils.device import get_device
@@ -441,3 +446,169 @@ def _dump_trainer_config(self, trainer_cfg: TrainerConfig):
trainer_cfg.model_dump_json()
trainer_cfg.model_dump()
pickle.dumps(trainer_cfg)


class CheckpointHookPickle:
def __init__(self) -> None:
self.count = 0

def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
self.count += 1


class TestHooksConfig(DeterministicDDPTestCase):
TOTAL_STEP = 10
CHECKPOINT_INTERVAL = 5
SNAPSHOT_INTERVAL = 2
HF_INTERVAL = 10
ERROR_MESG_PREFIX = "[HooksConfig Test Failed]: "

def _build_trainer(self, hooks_config: HooksConfig):
model_cfg = Qwen3MoE30BA3Config(num_hidden_layers=2, hidden_size=1024, moe_intermediate_size=384)
dataset_config = [
{
"dataset": DatasetConfig(name="alpaca", anno_path=os.environ["ALPACA_PATH"], sample_ratio=1.0),
"tokenize_fn": OpenaiTokenizeFunctionConfig(
max_length=100, chat_template="qwen3"
),
# "tokenize_fn": FTDPTokenizeFnConfig(max_length=16386),
},
]
dataloader_config = DataloaderConfig(pack_max_length=100)

optim_cfg = AdamWConfig(lr=0.1, weight_decay=0.1)
lr_cfg = LRConfig(lr_type="cosine", lr_min=0.001, warmup_ratio=0.03)

work_dir = tempfile.TemporaryDirectory().name
if dist.get_rank() == 0:
work_dir_list = [work_dir]
else:
work_dir_list = [None]
dist.broadcast_object_list(work_dir_list, src=0)
work_dir = work_dir_list[0]

trainer_cfg = TrainerConfig(
model_cfg=model_cfg,
optim_cfg=optim_cfg,
dataset_cfg=dataset_config,
dataloader_cfg=dataloader_config,
lr_cfg=lr_cfg,
loss_cfg=CELossConfig(mode="chunk", chunk_size=1024),
global_batch_size=self.world_size,
sp_size=1,
total_step=self.TOTAL_STEP,
seed=42,
checkpoint_interval=self.CHECKPOINT_INTERVAL,
snapshot_interval=self.SNAPSHOT_INTERVAL,
hf_interval=self.HF_INTERVAL,
tokenizer_path=os.environ["QWEN3_MOE_PATH"],
work_dir=work_dir,
hooks_config=hooks_config,
)
return Trainer.from_config(trainer_cfg)

def _cleanup_trainer(self, trainer: Trainer):
if dist.get_rank() == 0:
shutil.rmtree(trainer.work_dir, ignore_errors=True)
dist.barrier()

def test_hooks_config(self):
self.create_pg(DEVICE)
checkpoint_function_call_times = 0
train_step_function_call_times = 0
losslog_adapter = TypeAdapter(LossLog)
otherlog_adapter = TypeAdapter(OtherLog)

def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
nonlocal checkpoint_function_call_times
checkpoint_function_call_times += 1

def train_step_hook(loss_log, other_log, step, epoch, total_step, total_epoch):
nonlocal train_step_function_call_times
train_step_function_call_times += 1


class CheckpointHook:
def __init__(self) -> None:
self.count = 0

def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
self.count += 1

class TrainStepHook:
def connect_trainer(self, trainer: Trainer):
self.trainer = weakref.ref(trainer)

def __init__(self) -> None:
self.count = 0

def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
losslog_adapter.validate_python(loss_log)
otherlog_adapter.validate_python(other_log)

assert self.trainer().cur_step == step
assert self.trainer().cur_epoch == epoch
assert self.trainer().total_step == total_step
assert self.trainer().total_epoch == total_epoch

self.count += 1

hooks_config = HooksConfig(
after_save_dcp=[checkpoint_hook, CheckpointHook()],
after_train_step=[train_step_hook, TrainStepHook()],
after_save_hf=CheckpointHook(),
after_save_snapshot=CheckpointHook(),
)
trainer = self._build_trainer(hooks_config)
trainer.fit()

self.assertEqual(
checkpoint_function_call_times,
2,
self.ERROR_MESG_PREFIX + "Checkpoint hook not called expected times",
)
self.assertEqual(
train_step_function_call_times,
10,
self.ERROR_MESG_PREFIX + "Train step hook not called expected times",
)
self.assertEqual(
hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)[1].count,
10,
self.ERROR_MESG_PREFIX + "Train step hook not called expected times",
)
self.assertEqual(
hooks_config.get_hooks(HookStage.AFTER_SAVE_DCP)[1].count,
2,
self.ERROR_MESG_PREFIX + "Checkpoint hook not called expected times",
)
self.assertEqual(
hooks_config.get_hooks(HookStage.AFTER_SAVE_HF)[0].count,
1,
self.ERROR_MESG_PREFIX + "HF checkpoint hook not called expected times",
)
# The last snapshot is not saved because a dcp checkpoint has already been saved at that step.
self.assertEqual(
hooks_config.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT)[0].count,
4,
self.ERROR_MESG_PREFIX + "Snapshot hook not called expected times",
)
self._cleanup_trainer(trainer)

def test_serialize_hooks_config(self):
self.create_pg(DEVICE)
class CheckpointHook:
def __init__(self) -> None:
self.count = 0

def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
self.count += 1

hooks_config = HooksConfig(
after_train_step=CheckpointHook(),
after_save_dcp=CheckpointHookPickle(),
)
dumped = pickle.dumps(hooks_config)
loaded = pickle.loads(dumped)
assert len(loaded.get_hooks(HookStage.AFTER_TRAIN_STEP)) == 0 # <local> object cannot be serialized
assert len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)) == 1
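For reference, a minimal usage sketch of the HooksConfig / HookStage API exercised by the tests above (not part of this diff). Only the constructor arguments, hook signatures, connect_trainer callback, and get_hooks lookups shown in the tests are assumed; LoggingTrainStepHook and on_save_dcp are illustrative names.

import weakref

from xtuner.v1.train.trainer import HooksConfig, HookStage


class LoggingTrainStepHook:
    """Illustrative hook that records the total loss reported after each train step."""

    def __init__(self) -> None:
        self.history: list[tuple[int, float]] = []

    def connect_trainer(self, trainer):
        # Mirrors TrainStepHook above: keep only a weak reference so the hook
        # does not extend the trainer's lifetime.
        self.trainer = weakref.ref(trainer)

    def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
        self.history.append((step, loss_log["total_loss"]))


def on_save_dcp(checkpoint, step, epoch, total_step, total_epoch):
    # Defined at module level so the resulting HooksConfig stays picklable
    # (see test_serialize_hooks_config above).
    print(f"dcp checkpoint saved at step {step}/{total_step}")


hooks_config = HooksConfig(
    after_train_step=LoggingTrainStepHook(),
    after_save_dcp=on_save_dcp,
)
assert len(hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)) == 1
# The config is then passed to TrainerConfig(hooks_config=hooks_config, ...),
# as _build_trainer does above.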
5 changes: 5 additions & 0 deletions xtuner/_testing/testcase.py
@@ -1,5 +1,6 @@
from torch.testing._internal.common_distributed import DistributedTestBase, MultiProcessTestCase, logger, TEST_SKIPS, c10d
import torch
import torch.distributed as dist
import threading
import sys
import os
@@ -91,3 +92,7 @@ def _check_loss_curve(
raise AssertionError(
f"Failed to check relative error of loss, expected: {losses_ref}, got {losses}, Mean diff: {avg_relative_diff}")

def create_pg(self, device):
ret = super().create_pg(device)
os.environ["LOCAL_RANK"] = str(dist.get_rank() % torch.cuda.device_count())
return ret
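A short sketch of how a test can rely on this override (illustrative only; MyDistributedTest is hypothetical, and DeterministicDDPTestCase / get_device are assumed to behave as in the existing suite):

import os

from xtuner._testing import DeterministicDDPTestCase
from xtuner.v1.utils.device import get_device


class MyDistributedTest(DeterministicDDPTestCase):
    def test_local_rank_is_exported(self):
        # create_pg initializes the process group and now also sets LOCAL_RANK
        # from rank % device count, so environment-based device selection works
        # inside the spawned test processes.
        self.create_pg(get_device())
        assert "LOCAL_RANK" in os.environ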
4 changes: 3 additions & 1 deletion xtuner/v1/engine/__init__.py
@@ -1,11 +1,13 @@
from xtuner.v1.engine.config import EngineConfig

from .train_engine import TrainEngine
from .train_engine import LossLog, OtherLog, TrainEngine
from .vision_compose_train_engine import VisionComposeTrainEngine


__all__ = [
"TrainEngine",
"EngineConfig",
"VisionComposeTrainEngine",
"LossLog",
"OtherLog",
]
24 changes: 21 additions & 3 deletions xtuner/v1/engine/train_engine.py
@@ -8,6 +8,7 @@
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from pydantic import ConfigDict
from safetensors import safe_open
from torch.distributed.checkpoint.state_dict import (
StateDictOptions,
@@ -21,6 +22,7 @@
from torch.utils._foreach_utils import (
_device_has_foreach_support,
)
from typing_extensions import NotRequired, TypedDict

from xtuner.v1.config import FSDPConfig, OptimConfig
from xtuner.v1.data_proto.sequence_context import SequenceContext
@@ -40,6 +42,22 @@
threading_lock = threading.Lock()


class LossLog(TypedDict):
__pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc]
total_loss: float
reduced_llm_loss: float
reduced_balancing_loss: NotRequired[float]
reduced_z_loss: NotRequired[float]


class OtherLog(TypedDict):
__pydantic_config__ = ConfigDict(arbitrary_types_allowed=True) # type: ignore[misc]
maxvio: NotRequired[float]
consumed_tokens: float
extra_info: ModelForwardExtraLogInfo
efficient_attn_ratio: float


class CPUThreadTaskCoordinator:
def __init__(self, futures, callback):
self.futures = futures
@@ -199,7 +217,7 @@ def grad_accumulation_steps(self, data_batches_len: int):
intra_layer_micro_batch = self.intra_layer_micro_batch
return data_batches_len // intra_layer_micro_batch

def train_step(self, data_batches: list[ModelItem]):
def train_step(self, data_batches: list[ModelItem]) -> tuple[LossLog, OtherLog]:
"""Perform a training step with the given data batches and mesh.

Args:
@@ -208,8 +226,8 @@ def train_step(self, data_batches: list[ModelItem]):
if self.float8_handler is not None and self.float8_handler.enabled:
self.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model)

loss_log = {}
other_log: Dict[str, Any] = {}
loss_log: LossLog = {} # type: ignore[typeddict-item]
other_log: OtherLog = {} # type: ignore[typeddict-item]
intra_layer_micro_batch = self.intra_layer_micro_batch
assert len(data_batches) % intra_layer_micro_batch == 0, (
f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
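A minimal sketch of checking the new TypedDict return types with pydantic, as the hook test in tests/train/test_trainer.py does; the sample values are illustrative:

from pydantic import TypeAdapter

from xtuner.v1.engine import LossLog

loss_log_adapter = TypeAdapter(LossLog)

# Required keys must be present; NotRequired keys such as reduced_balancing_loss
# and reduced_z_loss may be omitted.
loss_log: LossLog = {"total_loss": 2.31, "reduced_llm_loss": 2.31}
loss_log_adapter.validate_python(loss_log)

# Missing a required key would raise pydantic.ValidationError:
# loss_log_adapter.validate_python({"total_loss": 2.31})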
8 changes: 4 additions & 4 deletions xtuner/v1/engine/vision_compose_train_engine.py
@@ -20,7 +20,7 @@
from xtuner.v1.module.router import NoAuxRouterConfig
from xtuner.v1.utils import get_device, get_logger, get_torch_device_module

from .train_engine import TrainEngine
from .train_engine import LossLog, OtherLog, TrainEngine


logger = get_logger()
@@ -146,7 +146,7 @@ def save_hf(self, hf_dir: str, save_dtype: torch.dtype = torch.bfloat16):
if self._processor is not None:
self._processor.save_pretrained(hf_dir)

def train_step(self, data_batches: List[ModelItem]):
def train_step(self, data_batches: List[ModelItem]) -> tuple[LossLog, OtherLog]:
"""Perform a training step with the given data batches and mesh.

Args:
@@ -159,8 +159,8 @@ def train_step(self, data_batches: List[ModelItem]):
if self.projector_float8_handler is not None and self.projector_float8_handler.enabled:
self.projector_float8_handler.precompute_float8_dynamic_scale_for_fsdp(self.model.multi_modal_projector)

loss_log = {}
other_log = {}
loss_log: LossLog = {} # type: ignore[typeddict-item]
other_log: OtherLog = {} # type: ignore[typeddict-item]
intra_layer_micro_batch = self.intra_layer_micro_batch
assert len(data_batches) % intra_layer_micro_batch == 0, (
f"data_batches length {len(data_batches)} is not divisible by intra_layer_micro_batch {intra_layer_micro_batch}"
4 changes: 2 additions & 2 deletions xtuner/v1/rl/base/worker.py
@@ -472,10 +472,10 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
loss_log, other_log = self._engine.train_step(
data_batches=engine_input,
)
other_log = self._update_other_log(other_log)
other_log = self._update_other_log(other_log) # type: ignore[arg-type]
grad_norm = self._engine.clip_grad_norm()
self._engine.step_optimizer(grad_norm)
log_info = dict()
log_info = dict() # type: ignore[var-annotated]
log_info.update(loss_log)
for k, v in other_log.items():
if k == "extra_info":