Skip to content

Commit d3887ce

Browse files
committed
[CI] Update unittest for hook
1 parent 8b77c21 commit d3887ce

File tree

1 file changed

+172
-1
lines changed

1 file changed

+172
-1
lines changed

tests/train/test_trainer.py

Lines changed: 172 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from pathlib import Path
44
from unittest.mock import patch, Mock
55
import pickle
6+
import shutil
7+
import weakref
8+
from pydantic import TypeAdapter
69

710
import torch
811
import torch.distributed as dist
@@ -18,11 +21,13 @@
1821
InternVL3P5Dense8BConfig,
1922
InternVL3P5MoE30BA3Config,
2023
)
21-
from xtuner.v1.train.trainer import Trainer, ResumeConfig
24+
from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage
2225
from xtuner.v1.datasets import FTDPTokenizeFnConfig
2326
from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig
2427
from xtuner.v1.train.trainer import TrainerConfig
28+
from xtuner.v1.engine.train_engine import LossLog, OtherLog
2529
from xtuner.v1.loss import CELossConfig
30+
from xtuner._testing import DeterministicDDPTestCase
2631
from unittest import TestCase
2732

2833
from xtuner.v1.utils.device import get_device
@@ -441,3 +446,169 @@ def _dump_trainer_config(self, trainer_cfg: TrainerConfig):
441446
trainer_cfg.model_dump_json()
442447
trainer_cfg.model_dump()
443448
pickle.dumps(trainer_cfg)
449+
450+
451+
class CheckpointHookPickle:
    """Call-counting checkpoint hook declared at module scope.

    ``test_serialize_hooks_config`` registers an instance of this class and
    expects it to still be present after a pickle round trip, in contrast to
    a hook class declared inside the test function.
    """

    def __init__(self) -> None:
        # Number of times the hook has been invoked so far.
        self.count = 0

    def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
        """Record one invocation; every argument is accepted but ignored."""
        self.count = self.count + 1
457+
458+
459+
class TestHooksConfig(DeterministicDDPTestCase):
    """Distributed tests for ``HooksConfig``.

    Covers two things: that hooks registered for each stage are invoked the
    expected number of times during ``Trainer.fit``, and how a ``HooksConfig``
    behaves across a pickle round trip.
    """

    TOTAL_STEP = 10
    CHECKPOINT_INTERVAL = 5
    SNAPSHOT_INTERVAL = 2
    HF_INTERVAL = 10
    ERROR_MESG_PREFIX = "[HooksConfig Test Failed]: "

    def _build_trainer(self, hooks_config: HooksConfig) -> Trainer:
        """Build a small ``Trainer`` wired to ``hooks_config``.

        The process group must already be initialised (``create_pg``) because
        the work directory is shared between ranks with a collective call.

        Args:
            hooks_config: Hooks to attach to the trainer configuration.

        Returns:
            The trainer created by ``Trainer.from_config``.
        """
        model_cfg = Qwen3MoE30BA3Config(num_hidden_layers=2, hidden_size=1024, moe_intermediate_size=384)
        dataset_config = [
            {
                "dataset": DatasetConfig(name="alpaca", anno_path=os.environ["ALPACA_PATH"], sample_ratio=1.0),
                "tokenize_fn": OpenaiTokenizeFunctionConfig(max_length=100, chat_template="qwen3"),
            },
        ]
        dataloader_config = DataloaderConfig(pack_max_length=100)

        optim_cfg = AdamWConfig(lr=0.1, weight_decay=0.1)
        lr_cfg = LRConfig(lr_type="cosine", lr_min=0.001, warmup_ratio=0.03)

        # Only rank 0 creates the directory; the other ranks receive its path.
        # ``mkdtemp`` is used instead of ``TemporaryDirectory().name`` because
        # the latter's owning object is dropped immediately, which removes the
        # directory again and leaves only an unreserved path name behind.
        work_dir_list = [tempfile.mkdtemp() if dist.get_rank() == 0 else None]
        dist.broadcast_object_list(work_dir_list, src=0)
        work_dir = work_dir_list[0]

        trainer_cfg = TrainerConfig(
            model_cfg=model_cfg,
            optim_cfg=optim_cfg,
            dataset_cfg=dataset_config,
            dataloader_cfg=dataloader_config,
            lr_cfg=lr_cfg,
            loss_cfg=CELossConfig(mode="chunk", chunk_size=1024),
            global_batch_size=self.world_size,
            sp_size=1,
            total_step=self.TOTAL_STEP,
            seed=42,
            checkpoint_interval=self.CHECKPOINT_INTERVAL,
            snapshot_interval=self.SNAPSHOT_INTERVAL,
            hf_interval=self.HF_INTERVAL,
            tokenizer_path=os.environ["QWEN3_MOE_PATH"],
            work_dir=work_dir,
            hooks_config=hooks_config,
        )
        return Trainer.from_config(trainer_cfg)

    def _cleanup_trainer(self, trainer: Trainer):
        """Remove the trainer's work directory on rank 0, then synchronise all ranks."""
        if dist.get_rank() == 0:
            shutil.rmtree(trainer.work_dir, ignore_errors=True)
        dist.barrier()

    def test_hooks_config(self):
        """Run a short training job and check how often each hook was called."""
        self.create_pg(DEVICE)
        checkpoint_function_call_times = 0
        train_step_function_call_times = 0
        loss_log_adapter = TypeAdapter(LossLog)
        other_log_adapter = TypeAdapter(OtherLog)

        def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
            nonlocal checkpoint_function_call_times
            checkpoint_function_call_times += 1

        def train_step_hook(loss_log, other_log, step, epoch, total_step, total_epoch):
            nonlocal train_step_function_call_times
            train_step_function_call_times += 1

        class CheckpointHook:
            def __init__(self) -> None:
                self.count = 0

            def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
                self.count += 1

        class TrainStepHook:
            def __init__(self) -> None:
                self.count = 0

            def connect_trainer(self, trainer: Trainer):
                # Weak reference so the hook does not keep the trainer alive.
                self.trainer = weakref.ref(trainer)

            def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
                # Both payloads must validate against their declared types.
                loss_log_adapter.validate_python(loss_log)
                other_log_adapter.validate_python(other_log)

                # Dereference once; a dead reference yields None.
                trainer = self.trainer()
                assert trainer is not None, "trainer was garbage collected before the hook ran"
                assert trainer.cur_step == step
                assert trainer.cur_epoch == epoch
                assert trainer.total_step == total_step
                assert trainer.total_epoch == total_epoch

                self.count += 1

        hooks_config = HooksConfig(
            after_save_dcp=[checkpoint_hook, CheckpointHook()],
            after_train_step=[train_step_hook, TrainStepHook()],
            after_save_hf=CheckpointHook(),
            after_save_snapshot=CheckpointHook(),
        )

        # Expected counts follow from the class constants (2, 10 and 1 with the
        # current values). NOTE(review): this assumes one save per full
        # interval - confirm against the trainer if the constants change.
        expected_dcp_saves = self.TOTAL_STEP // self.CHECKPOINT_INTERVAL
        expected_train_steps = self.TOTAL_STEP
        expected_hf_saves = self.TOTAL_STEP // self.HF_INTERVAL
        # Snapshot steps would be 2, 4, 6, 8 and 10; the last snapshot is not
        # saved because a DCP checkpoint has already been saved at that step.
        expected_snapshots = 4

        trainer = self._build_trainer(hooks_config)
        try:
            trainer.fit()

            self.assertEqual(
                checkpoint_function_call_times,
                expected_dcp_saves,
                self.ERROR_MESG_PREFIX + "Checkpoint hook not called expected times",
            )
            self.assertEqual(
                train_step_function_call_times,
                expected_train_steps,
                self.ERROR_MESG_PREFIX + "Train step hook not called expected times",
            )
            self.assertEqual(
                hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)[1].count,
                expected_train_steps,
                self.ERROR_MESG_PREFIX + "Train step hook not called expected times",
            )
            self.assertEqual(
                hooks_config.get_hooks(HookStage.AFTER_SAVE_DCP)[1].count,
                expected_dcp_saves,
                self.ERROR_MESG_PREFIX + "Checkpoint hook not called expected times",
            )
            self.assertEqual(
                hooks_config.get_hooks(HookStage.AFTER_SAVE_HF)[0].count,
                expected_hf_saves,
                self.ERROR_MESG_PREFIX + "HF checkpoint hook not called expected times",
            )
            self.assertEqual(
                hooks_config.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT)[0].count,
                expected_snapshots,
                self.ERROR_MESG_PREFIX + "Snapshot hook not called expected times",
            )
        finally:
            # Always remove the work directory, even when training or an
            # assertion above fails.
            self._cleanup_trainer(trainer)

    def test_serialize_hooks_config(self):
        """Pickle a ``HooksConfig`` and check which hooks are present afterwards."""
        self.create_pg(DEVICE)

        class CheckpointHook:
            def __init__(self) -> None:
                self.count = 0

            def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
                self.count += 1

        hooks_config = HooksConfig(
            after_train_step=CheckpointHook(),
            after_save_dcp=CheckpointHookPickle(),
        )
        loaded = pickle.loads(pickle.dumps(hooks_config))

        # A <local> object cannot be serialized, so the locally defined hook
        # is expected to be missing after the round trip.
        self.assertEqual(
            len(loaded.get_hooks(HookStage.AFTER_TRAIN_STEP)),
            0,
            self.ERROR_MESG_PREFIX + "Locally defined hook should not survive pickling",
        )
        # The module-level hook class is expected to be preserved.
        self.assertEqual(
            len(loaded.get_hooks(HookStage.AFTER_SAVE_DCP)),
            1,
            self.ERROR_MESG_PREFIX + "Module-level hook should survive pickling",
        )

0 commit comments

Comments
 (0)