|
47 | 47 | log_format, |
48 | 48 | record_git_info, |
49 | 49 | ) |
| 50 | +from xtuner.v1.utils.check_health import check_health |
50 | 51 | from xtuner.v1.utils.device import get_device, get_torch_device_module |
51 | 52 |
|
52 | 53 | from .toy_tokenizer import UTF8ByteTokenizer |
@@ -172,6 +173,7 @@ class TrainerConfig(BaseModel): |
172 | 173 | checkpoint_maxkeep: int | None = -1 |
173 | 174 | skip_checkpoint_validation: bool = False # Suggest enabled if fsdp_size is larger than 512 |
174 | 175 | snapshot_interval: int | None = None |
| 176 | + check_health_interval: int | None = None |
175 | 177 | hf_interval: int | None = None |
176 | 178 | hf_max_keep: int | None = None |
177 | 179 | exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl" |
@@ -289,6 +291,7 @@ def __init__( |
289 | 291 | checkpoint_maxkeep: int | None = -1, |
290 | 292 | skip_checkpoint_validation: bool = False, # Suggest enabled if fsdp_size is larger than 512 |
291 | 293 | snapshot_interval: int | None = None, |
| 294 | + check_health_interval: int | None = None, |
292 | 295 | hf_interval: int | None = None, |
293 | 296 | hf_max_keep: int | None = None, |
294 | 297 | exp_tracker: Literal["tensorboard", "jsonl"] = "jsonl", |
@@ -337,6 +340,7 @@ def __init__( |
337 | 340 | self._checkpoint_interval = checkpoint_interval |
338 | 341 | self._checkpoint_maxkeep = checkpoint_maxkeep |
339 | 342 | self._snapshot_interval = snapshot_interval |
| 343 | + self._check_health_interval = check_health_interval |
340 | 344 | self._hf_max_keep = hf_max_keep |
341 | 345 | self._hf_interval = hf_interval |
342 | 346 |
|
@@ -481,6 +485,7 @@ def from_config(cls, config: TrainerConfig) -> Self: |
481 | 485 | checkpoint_maxkeep=config.checkpoint_maxkeep, |
482 | 486 | skip_checkpoint_validation=config.skip_checkpoint_validation, |
483 | 487 | snapshot_interval=config.snapshot_interval, |
| 488 | + check_health_interval=config.check_health_interval, |
484 | 489 | hf_interval=config.hf_interval, |
485 | 490 | hf_max_keep=config.hf_max_keep, |
486 | 491 | exp_tracker=config.exp_tracker, |
@@ -586,6 +591,7 @@ def fit(self): |
586 | 591 | ) |
587 | 592 |
|
588 | 593 | self._lr_scheduler.step() |
| 594 | + self._maybe_check_health() |
589 | 595 | self._maybe_save_hf() |
590 | 596 | ckpt_saved = self._maybe_save(is_snapshot=False) |
591 | 597 | if not ckpt_saved: |
@@ -806,6 +812,16 @@ def warmup_fn(x): |
806 | 812 | ) |
807 | 813 | return lr_scheduler |
808 | 814 |
|
| 815 | + def _maybe_check_health(self): |
| 816 | + if ( |
| 817 | + (self._check_health_interval is not None and self.cur_step % self._check_health_interval == 0) |
| 818 | + or (self._checkpoint_interval is not None and self.cur_step % self._checkpoint_interval == 0) |
| 819 | + or (self._snapshot_interval is not None and self.cur_step % self._snapshot_interval == 0) |
| 820 | + ): |
| 821 | + if not check_health(): |
| 822 | + raise RuntimeError("Health check failed, exit training") |
| 823 | + logger.info(f"Health check passed at step {self.cur_step}") |
| 824 | + |
809 | 825 | def _maybe_save(self, is_snapshot: bool = False) -> bool: |
810 | 826 | ckp_interval = self._checkpoint_interval if not is_snapshot else self._snapshot_interval |
811 | 827 | if ckp_interval is None: |
|
0 commit comments