-
Notifications
You must be signed in to change notification settings - Fork 387
[Feature] Add hooks in trainer #1249
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
72c02e4 to
df75250
Compare
df75250 to
039a00b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds a hooks system to the trainer, allowing users to inject custom logic at specific training stages. The hooks can be triggered after saving checkpoints (DCP, HF, snapshots) and after each training step.
Key changes:
- Introduced
HooksConfigwith four hook stages:after_save_dcp,after_save_hf,after_save_snapshot, andafter_train_step - Added Protocol definitions for type-safe hook implementations with optional
connect_trainermethod - Created
LossLogandOtherLogTypedDict classes to formalize logging data structures
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
Show a summary per file
| File | Description |
|---|---|
| xtuner/v1/train/trainer.py | Added hook Protocol definitions, HooksConfig class, hook execution logic in fit and save methods, new properties for cur_epoch and total_epoch |
| xtuner/v1/engine/train_engine.py | Defined LossLog and OtherLog TypedDict classes, updated train_step return type annotation |
| xtuner/v1/engine/vision_compose_train_engine.py | Updated train_step return type annotation and added type annotations for loss_log and other_log |
| xtuner/v1/engine/init.py | Exported LossLog and OtherLog types |
| xtuner/v1/rl/base/worker.py | Added type ignore comments for type checker compatibility |
| xtuner/_testing/testcase.py | Added create_pg method to set LOCAL_RANK environment variable |
| tests/train/test_trainer.py | Added comprehensive test for hooks functionality |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _validate_hooks( | ||
| cls, | ||
| value: list[HookProtocol] | HookProtocol | None, # noqa: E501 | ||
| ) -> list[Callable] | None: |
Copilot
AI
Nov 21, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Normal methods should have 'self', rather than 'cls', as their first parameter.
039a00b to
d3887ce
Compare
|
fix lint |
1. **`after_save_dcp`** - Triggered by `Trainer._maybe_save()` (non-snapshot saves) - Executes after saving distributed checkpoint 2. **`after_save_snapshot`** - Triggered by `Trainer._maybe_save()` when saving snapshot - Executes after creating training snapshot 3. **`after_save_hf`** - Triggered by `Trainer._maybe_save_hf()` - Executes after saving in HuggingFace format 4. **`after_train_step`** - Triggered after `TrainEngine.train_step()` completes - Executes after each training step - Hooks are optional and backward-compatible - Each hook receives context (step number, paths, etc.) - No performance impact when not registered
d3887ce to
cf18fa5
Compare
No description provided.