
Commit df75250

[CI] Update unittest for hook
1 parent 5300790 commit df75250

File tree

1 file changed: +120 -1 lines changed

tests/train/test_trainer.py

Lines changed: 120 additions & 1 deletion
@@ -3,6 +3,9 @@
 from pathlib import Path
 from unittest.mock import patch, Mock
 import pickle
+import shutil
+import weakref
+from pydantic import TypeAdapter
 
 import torch
 import torch.distributed as dist
@@ -18,11 +21,13 @@
     InternVL3P5Dense8BConfig,
     InternVL3P5MoE30BA3Config,
 )
-from xtuner.v1.train.trainer import Trainer, ResumeConfig
+from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
 from xtuner.v1.datasets import FTDPTokenizeFnConfig
 from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
 from xtuner.v1.train.trainer import TrainerConfig
+from xtuner.v1.engine.train_engine import LossLog, OtherLog
 from xtuner.v1.loss import CELossConfig
+from xtuner._testing import DeterministicDDPTestCase
 from unittest import TestCase
 
 from xtuner.v1.utils.device import get_device
@@ -440,3 +445,117 @@ def _dump_trainer_config(self, trainer_cfg: TrainerConfig):
         trainer_cfg.model_dump_json()
         trainer_cfg.model_dump()
         pickle.dumps(trainer_cfg)
+
+
+class TestHooksConfig(DeterministicDDPTestCase):
+    TOTAL_STEP = 10
+    CHECKPOINT_INTERVAL = 5
+    SNAPSHOT_INTERVAL = 2
+    HF_INTERVAL = 10
+
+    def _build_trainer(self, hooks_config: HooksConfig):
+        model_cfg = Qwen3MoE30BA3Config(num_hidden_layers=2, hidden_size=1024, moe_intermediate_size=384)
+        dataset_config = [
+            {
+                "dataset": DatasetConfig(name="alpaca", anno_path=os.environ["ALPACA_PATH"], sample_ratio=1.0),
+                "tokenize_fn": OpenaiTokenizeFunctionConfig(
+                    max_length=100, chat_template="qwen3"
+                ),
+                # "tokenize_fn": FTDPTokenizeFnConfig(max_length=16386),
+            },
+        ]
+        dataloader_config = DataloaderConfig(pack_max_length=100)
+
+        optim_cfg = AdamWConfig(lr=0.1, weight_decay=0.1)
+        lr_cfg = LRConfig(lr_type="cosine", lr_min=0.001, warmup_ratio=0.03)
+
+        # Rank 0 picks the temporary work dir and broadcasts it so all ranks share one path.
+        work_dir = tempfile.TemporaryDirectory().name
+        if dist.get_rank() == 0:
+            work_dir_list = [work_dir]
+        else:
+            work_dir_list = [None]
+        dist.broadcast_object_list(work_dir_list, src=0)
+        work_dir = work_dir_list[0]
+
+        trainer_cfg = TrainerConfig(
+            model_cfg=model_cfg,
+            optim_cfg=optim_cfg,
+            dataset_cfg=dataset_config,
+            dataloader_cfg=dataloader_config,
+            lr_cfg=lr_cfg,
+            loss_cfg=CELossConfig(mode="chunk", chunk_size=1024),
+            global_batch_size=self.world_size,
+            sp_size=1,
+            total_step=self.TOTAL_STEP,
+            seed=42,
+            checkpoint_interval=self.CHECKPOINT_INTERVAL,
+            snapshot_interval=self.SNAPSHOT_INTERVAL,
+            hf_interval=self.HF_INTERVAL,
+            tokenizer_path=os.environ["QWEN3_MOE_PATH"],
+            work_dir=work_dir,
+            hooks_config=hooks_config,
+        )
+        return Trainer.from_config(trainer_cfg)
+
+    def _cleanup_trainer(self, trainer: Trainer):
+        if dist.get_rank() == 0:
+            shutil.rmtree(trainer.work_dir, ignore_errors=True)
+        dist.barrier()
+
+    def test_hooks_config(self):
+        self.create_pg(DEVICE)
+        checkpoint_function_call_times = 0
+        train_step_function_call_times = 0
+        losslog_adapter = TypeAdapter(LossLog)
+        otherlog_adapter = TypeAdapter(OtherLog)
+
+        def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
+            nonlocal checkpoint_function_call_times
+            checkpoint_function_call_times += 1
+
+        def train_step_hook(loss_log, other_log, step, epoch, total_step, total_epoch):
+            nonlocal train_step_function_call_times
+            train_step_function_call_times += 1
+
+
+        class CheckpointHook:
+            def __init__(self) -> None:
+                self.count = 0
+
+            def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
+                self.count += 1
+
+        class TrainStepHook:
+            def __init__(self) -> None:
+                self.count = 0
+
+            def connect_trainer(self, trainer: Trainer):
+                # Hold the trainer weakly so the hook does not keep it alive.
+                self.trainer = weakref.ref(trainer)
+
+            def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
+                losslog_adapter.validate_python(loss_log)
+                otherlog_adapter.validate_python(other_log)
+
+                assert self.trainer().cur_step == step
+                assert self.trainer().cur_epoch == epoch
+                assert self.trainer().total_step == total_step
+                assert self.trainer().total_epoch == total_epoch
+
+                self.count += 1
+
+        hooks_config = HooksConfig(
+            after_save_dcp=[checkpoint_hook, CheckpointHook()],
+            after_train_step=[train_step_hook, TrainStepHook()],
+            after_save_hf=CheckpointHook(),
+            after_save_snapshot=CheckpointHook(),
+        )
+        trainer = self._build_trainer(hooks_config)
+        trainer.fit()
+
+        self.assertEqual(checkpoint_function_call_times, self.TOTAL_STEP // self.CHECKPOINT_INTERVAL)
+        self.assertEqual(train_step_function_call_times, self.TOTAL_STEP)
+        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)[1].count, self.TOTAL_STEP)
+        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_SAVE_DCP)[1].count, self.TOTAL_STEP // self.CHECKPOINT_INTERVAL)
+        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_SAVE_HF)[0].count, self.TOTAL_STEP // self.HF_INTERVAL)
+        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT)[0].count, self.TOTAL_STEP // self.SNAPSHOT_INTERVAL)
+        self._cleanup_trainer(trainer)
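The test above registers both plain functions and callable objects for several hook stages, runs fit(), and then checks the invocation count of each hook. As a standalone illustration of that registration pattern, here is a minimal sketch; it assumes only the xtuner API visible in this diff (HooksConfig, HookStage, get_hooks, and the after_train_step hook signature), and the StepCounter class is hypothetical.

# Minimal sketch of the hook-registration pattern exercised above.
# Assumption: only the API visible in this diff (HooksConfig, HookStage,
# get_hooks, and the after_train_step signature); StepCounter is hypothetical.
from xtuner.v1.train.trainer import HooksConfig, HookStage


class StepCounter:
    """Callable hook that counts completed train steps."""

    def __init__(self) -> None:
        self.count = 0

    def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
        self.count += 1


counter = StepCounter()

# A stage accepts a single callable or a list of callables; plain
# functions and callable objects are interchangeable, as the test shows.
hooks_config = HooksConfig(after_train_step=[counter])

# Hooks registered for a stage can be read back by stage.
assert hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)[0].count == 0

Hooks that additionally define a connect_trainer method, like TrainStepHook in the test, appear to receive the trainer before training starts; holding it through weakref.ref keeps the hook from pinning the trainer in memory.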
