3 | 3 | from pathlib import Path
4 | 4 | from unittest.mock import patch, Mock
5 | 5 | import pickle
| 6 | +import shutil |
| 7 | +import weakref |
| 8 | +from pydantic import TypeAdapter |
6 | 9 |
7 | 10 | import torch |
8 | 11 | import torch.distributed as dist |
18 | 21 |     InternVL3P5Dense8BConfig,
19 | 22 |     InternVL3P5MoE30BA3Config,
20 | 23 | )
21 | | -from xtuner.v1.train.trainer import Trainer, ResumeConfig |
| 24 | +from xtuner.v1.train.trainer import HooksConfig, Trainer, ResumeConfig, HookStage |
22 | 25 | from xtuner.v1.datasets import FTDPTokenizeFnConfig |
23 | 26 | from xtuner.v1.datasets.sft_tokenize_fn import OpenaiTokenizeFunctionConfig |
24 | 27 | from xtuner.v1.train.trainer import TrainerConfig |
| 28 | +from xtuner.v1.engine.train_engine import LossLog, OtherLog |
25 | 29 | from xtuner.v1.loss import CELossConfig |
| 30 | +from xtuner._testing import DeterministicDDPTestCase |
26 | 31 | from unittest import TestCase |
27 | 32 |
28 | 33 | from xtuner.v1.utils.device import get_device |
@@ -440,3 +445,117 @@ def _dump_trainer_config(self, trainer_cfg: TrainerConfig):
440 | 445 |         trainer_cfg.model_dump_json()
441 | 446 |         trainer_cfg.model_dump()
442 | 447 |         pickle.dumps(trainer_cfg)
| 448 | + |
| 449 | + |
| 450 | +class TestHooksConfig(DeterministicDDPTestCase):
| 451 | +    TOTAL_STEP = 10
| 452 | +    CHECKPOINT_INTERVAL = 5
| 453 | +    SNAPSHOT_INTERVAL = 2
| 454 | +    HF_INTERVAL = 10
| 455 | +
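| | +    # Build a small Qwen3-MoE trainer whose checkpoint/snapshot/HF intervals come
| | +    # from the class constants above, so hook call counts can be asserted exactly.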
| 456 | +    def _build_trainer(self, hooks_config: HooksConfig):
| 457 | +        model_cfg = Qwen3MoE30BA3Config(num_hidden_layers=2, hidden_size=1024, moe_intermediate_size=384)
| 458 | +        dataset_config = [
| 459 | +            {
| 460 | +                "dataset": DatasetConfig(name="alpaca", anno_path=os.environ["ALPACA_PATH"], sample_ratio=1.0),
| 461 | +                "tokenize_fn": OpenaiTokenizeFunctionConfig(
| 462 | +                    max_length=100, chat_template="qwen3"
| 463 | +                ),
| 464 | +                # "tokenize_fn": FTDPTokenizeFnConfig(max_length=16386),
| 465 | +            },
| 466 | +        ]
| 467 | +        dataloader_config = DataloaderConfig(pack_max_length=100)
| 468 | +
| 469 | +        optim_cfg = AdamWConfig(lr=0.1, weight_decay=0.1)
| 470 | +        lr_cfg = LRConfig(lr_type="cosine", lr_min=0.001, warmup_ratio=0.03)
| 471 | +
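| | +        # Only rank 0's temporary work dir is kept; broadcasting it ensures all
| | +        # ranks use the same path.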
| 472 | +        work_dir = tempfile.TemporaryDirectory().name
| 473 | +        if dist.get_rank() == 0:
| 474 | +            work_dir_list = [work_dir]
| 475 | +        else:
| 476 | +            work_dir_list = [None]
| 477 | +        dist.broadcast_object_list(work_dir_list, src=0)
| 478 | +        work_dir = work_dir_list[0]
| 479 | +
| 480 | +        trainer_cfg = TrainerConfig(
| 481 | +            model_cfg=model_cfg,
| 482 | +            optim_cfg=optim_cfg,
| 483 | +            dataset_cfg=dataset_config,
| 484 | +            dataloader_cfg=dataloader_config,
| 485 | +            lr_cfg=lr_cfg,
| 486 | +            loss_cfg=CELossConfig(mode="chunk", chunk_size=1024),
| 487 | +            global_batch_size=self.world_size,
| 488 | +            sp_size=1,
| 489 | +            total_step=self.TOTAL_STEP,
| 490 | +            seed=42,
| 491 | +            checkpoint_interval=self.CHECKPOINT_INTERVAL,
| 492 | +            snapshot_interval=self.SNAPSHOT_INTERVAL,
| 493 | +            hf_interval=self.HF_INTERVAL,
| 494 | +            tokenizer_path=os.environ["QWEN3_MOE_PATH"],
| 495 | +            work_dir=work_dir,
| 496 | +            hooks_config=hooks_config,
| 497 | +        )
| 498 | +        return Trainer.from_config(trainer_cfg)
| 499 | +
| 500 | +    def _cleanup_trainer(self, trainer: Trainer):
| 501 | +        if dist.get_rank() == 0:
| 502 | +            shutil.rmtree(trainer.work_dir, ignore_errors=True)
| 503 | +        dist.barrier()
| 504 | +
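| | +    # Register plain functions and stateful hook objects for each stage, run a short
| | +    # training job, and verify how many times every hook fired.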
| 505 | +    def test_hooks_config(self):
| 506 | +        self.create_pg(DEVICE)
| 507 | +        checkpoint_function_call_times = 0
| 508 | +        train_step_function_call_times = 0
| 509 | +        losslog_adapter = TypeAdapter(LossLog)
| 510 | +        otherlog_adapter = TypeAdapter(OtherLog)
| 511 | +
| 512 | +        def checkpoint_hook(checkpoint, step, epoch, total_step, total_epoch):
| 513 | +            nonlocal checkpoint_function_call_times
| 514 | +            checkpoint_function_call_times += 1
| 515 | +
| 516 | +        def train_step_hook(loss_log, other_log, step, epoch, total_step, total_epoch):
| 517 | +            nonlocal train_step_function_call_times
| 518 | +            train_step_function_call_times += 1
| 519 | +
| 520 | +
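| | +        # Class-based hooks track calls via an instance counter; TrainStepHook also
| | +        # stores a weak reference to the trainer through connect_trainer so it can
| | +        # check the step/epoch arguments against the trainer's own state.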
| 521 | +        class CheckpointHook:
| 522 | +            def __init__(self) -> None:
| 523 | +                self.count = 0
| 524 | +
| 525 | +            def __call__(self, checkpoint, step, epoch, total_step, total_epoch):
| 526 | +                self.count += 1
| 527 | +
| 528 | +        class TrainStepHook:
| 529 | +            def connect_trainer(self, trainer: Trainer):
| 530 | +                self.trainer = weakref.ref(trainer)
| 531 | +
| 532 | +            def __init__(self) -> None:
| 533 | +                self.count = 0
| 534 | +
| 535 | +            def __call__(self, loss_log, other_log, step, epoch, total_step, total_epoch):
| 536 | +                losslog_adapter.validate_python(loss_log)
| 537 | +                otherlog_adapter.validate_python(other_log)
| 538 | +
| 539 | +                assert self.trainer().cur_step == step
| 540 | +                assert self.trainer().cur_epoch == epoch
| 541 | +                assert self.trainer().total_step == total_step
| 542 | +                assert self.trainer().total_epoch == total_epoch
| 543 | +
| 544 | +                self.count += 1
| 545 | +
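| | +        # Each stage accepts either a single hook or a list of hooks.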
| 546 | +        hooks_config = HooksConfig(
| 547 | +            after_save_dcp=[checkpoint_hook, CheckpointHook()],
| 548 | +            after_train_step=[train_step_hook, TrainStepHook()],
| 549 | +            after_save_hf=CheckpointHook(),
| 550 | +            after_save_snapshot=CheckpointHook(),
| 551 | +        )
| 552 | +        trainer = self._build_trainer(hooks_config)
| 553 | +        trainer.fit()
| 554 | +
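| | +        # Expected counts follow from the configured intervals: train-step hooks fire
| | +        # every step; DCP, HF, and snapshot hooks fire every CHECKPOINT_INTERVAL,
| | +        # HF_INTERVAL, and SNAPSHOT_INTERVAL steps respectively.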
| 555 | +        self.assertEqual(checkpoint_function_call_times, self.TOTAL_STEP // self.CHECKPOINT_INTERVAL)
| 556 | +        self.assertEqual(train_step_function_call_times, self.TOTAL_STEP)
| 557 | +        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_TRAIN_STEP)[1].count, self.TOTAL_STEP)
| 558 | +        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_SAVE_DCP)[1].count, self.TOTAL_STEP // self.CHECKPOINT_INTERVAL)
| 559 | +        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_SAVE_HF)[0].count, self.TOTAL_STEP // self.HF_INTERVAL)
| 560 | +        self.assertEqual(hooks_config.get_hooks(HookStage.AFTER_SAVE_SNAPSHOT)[0].count, self.TOTAL_STEP // self.SNAPSHOT_INTERVAL)
| 561 | +        self._cleanup_trainer(trainer)