
Commit 84d55ac

Merge branch 'main' into refactor-http
2 parents 8a0d8bf + 3851e5f commit 84d55ac

File tree

13 files changed: +255 −36 lines changed

.github/workflows/unit_test.yaml

Lines changed: 2 additions & 2 deletions

@@ -10,7 +10,7 @@ on:
 env:
   WORKSPACE_PREFIX: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-5)
   WORKSPACE_PREFIX_SHORT: $(echo $GITHUB_WORKSPACE |cut -d '/' -f 1-3)
-  IMAGE: registry.h.pjlab.org.cn/ailab-llmrazor/xtuner:pt28_20250911_6652194_fix_pip
+  IMAGE: registry.h.pjlab.org.cn/ailab-llmrazor/xtuner:pt28_20251113_22badb0_grouped_router_topk1

 concurrency:
   group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
@@ -30,4 +30,4 @@ jobs:
       - name: unit-test
         run: |
           export PYTHONPYCACHEPREFIX=/tmp
-          python ci/scripts/xtuner_unittest.py "$IMAGE" "source ${{env.WORKSPACE_PREFIX}}/BASE_ENV.sh;source ci/scripts/CI_ENV.sh" "pytest tests --ignore=./tests/module/dispatcher/test_deepep.py"
+          python ci/scripts/xtuner_unittest.py "$IMAGE" "source ${{env.WORKSPACE_PREFIX}}/BASE_ENV.sh;source ci/scripts/CI_ENV.sh" "pytest tests"

ci/scripts/CI_ENV.sh

Lines changed: 0 additions & 2 deletions

@@ -28,7 +28,5 @@ export PYTEST_ADDOPTS='-o cache_dir=/tmp/.pytest_cache'

 proxy_off
 pip install -e .
-pip install openai-harmony
-pip install numpy==1.26.4

 export PYTHONPATH=${LM_DEPLOY}:$PYTHONPATH

tests/ray/test_update_weight.py

Lines changed: 149 additions & 0 deletions

@@ -0,0 +1,149 @@
+import os
+import unittest
+
+import ray
+
+from xtuner.v1.ray.base import AutoAcceleratorWorkers
+from xtuner.v1.ray.rollout import RolloutController
+from xtuner.v1.data_proto.rl_data import SampleParams
+from xtuner.v1.config import (
+    AdamWConfig,
+    FSDPConfig,
+    LRConfig,
+)
+from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker
+from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig
+from xtuner.v1.model import get_model_config_from_hf
+
+TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}]
+MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
+
+
+class TestUpdateWeight(unittest.TestCase):
+    def setUp(self):
+        os.environ["XTUNER_USE_FA3"] = "1"
+        ray.init(num_cpus=80, ignore_reinit_error=True)
+        self.model_path = MODEL_PATH
+        self.init_config()
+        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
+
+    def tearDown(self):
+        ray.shutdown()
+        del os.environ["XTUNER_USE_FA3"]
+
+    def init_config(self):
+        self.resources_cfg = AcceleratorResourcesConfig(
+            accelerator="GPU",
+            num_workers=2,
+            num_cpus_per_worker=16,
+            cpu_memory_per_worker=16 * 1024 ** 3,  # 16 GB
+        )
+        self.rollout_cfg = RolloutConfig(
+            env="test_rollout",
+            model_path=MODEL_PATH,
+            model_name=os.path.basename(MODEL_PATH).lower(),
+            tokenizer_path=MODEL_PATH,
+            rollout_cross_node_comm=False,
+            tensor_parallel_size=2,
+            expert_parallel_size=1,
+            gpus_per_node=8,  # gpu: 8, npu: 16
+            dtype="bfloat16",
+            skip_load_weights=True,
+            context_length=256,
+        )
+
+        # training config
+        model_cfg = get_model_config_from_hf(model_path=MODEL_PATH)
+        if hasattr(model_cfg, 'z_loss_cfg'):
+            model_cfg.z_loss_cfg = ZLossConfig()
+        if hasattr(model_cfg, 'balancing_loss_cfg'):
+            model_cfg.balancing_loss_cfg = BalancingLossConfig()
+        optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
+        fsdp_cfg: FSDPConfig = FSDPConfig()
+        lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7)
+        self.worker_cfg: WorkerConfig = WorkerConfig(
+            model_cfg=model_cfg,
+            optim_cfg=optim_cfg,
+            loss_cfg=LossConfig(
+                policy_loss_cfg=dict(
+                    cliprange_high=0.28,
+                    cliprange_low=0.2,
+                    loss_type="vanilla",
+                ),
+                ignore_idx=-100,
+                use_kl_loss=True,
+                kl_loss_coef=0.001,
+                kl_loss_type="low_var_kl",
+                mode="eager"),
+            lr_cfg=lr_cfg,
+            fsdp_cfg=fsdp_cfg,
+            load_from=MODEL_PATH,
+            sp_size=1,
+            pack_max_length=1024,
+        )
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_lmdeploy_update_weight_and_generate(self):
+        # init train
+        TrainingWorker = ray.remote(
+            runtime_env={
+                "env_vars": {
+                    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
+                    "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
+                }
+            },
+        )(BaseTrainingWorker)
+        train_workers, _ = AutoAcceleratorWorkers.from_placement_group(
+            TrainingWorker, self.worker_cfg, self.pg
+        )
+        futures = [worker.test_all_reduce.remote() for worker in train_workers]
+        ray.get(futures)
+        train_controller = TrainingController.remote(
+            workers=train_workers,
+        )
+        ray.get(train_controller.__ray_ready__.remote())
+
+        # fixed sample params
+        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
+
+        # init rollout_update
+        rollout_controller = RolloutController.remote(
+            self.rollout_cfg,
+            self.pg,
+        )
+        info_dict = ray.get(rollout_controller.get_rollout_info.remote())
+        ray.get(train_controller.update_rollout_info.remote(info_dict))
+
+        # update weights
+        ray.get(rollout_controller.offload.remote())
+        ray.get(rollout_controller.onload_weights.remote())
+        ray.get(train_controller.offload.remote(["optimizer"]))
+        ray.get(train_controller.update_weights.remote())
+        ray.get(train_controller.offload.remote(["model"]))
+        ray.get(rollout_controller.onload_kvcache.remote())
+
+        res_update_weight = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
+        ray.get(rollout_controller.shutdown.remote(), timeout=60)
+
+        # init rollout_ref
+        self.rollout_cfg.skip_load_weights = False
+        rollout_controller_ref = RolloutController.remote(
+            self.rollout_cfg,
+            self.pg,
+        )
+
+        res_ref = ray.get(rollout_controller_ref.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
+        ray.get(rollout_controller_ref.shutdown.remote(), timeout=60)
+
+        self.assertEqual(res_update_weight.response, res_ref.response)
+
+
+if __name__ == "__main__":
+    test_instance = TestUpdateWeight()
+    test_instance.setUp()
+    try:
+        test_instance.test_lmdeploy_update_weight_and_generate()
+    finally:
+        test_instance.tearDown()
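
Note: the test above builds its training actors by calling ray.remote() with options and then applying it to an existing class. A minimal, self-contained sketch of that pattern (independent of xtuner; the class, env var, and method names here are hypothetical):

import ray


class Worker:
    def ping(self) -> str:
        return "pong"


ray.init(ignore_reinit_error=True)
# ray.remote(...) called with keyword options returns a decorator, so an existing
# class can be wrapped at call time with a per-actor runtime_env instead of a
# class-level @ray.remote decorator.
RemoteWorker = ray.remote(
    runtime_env={"env_vars": {"MY_FLAG": "1"}},
)(Worker)

actor = RemoteWorker.remote()
print(ray.get(actor.ping.remote()))  # -> "pong"
ray.shutdown()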

tests/train/test_trainer.py

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ def __init__(self):

         self.model = model = nn.Linear(10, 10)
         self.optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+        self.has_freeze_params = False

     def grad_accumulation_steps(self, *args, **kwargs):
         return 1

xtuner/v1/data_proto/sequence_context.py

Lines changed: 17 additions & 16 deletions

@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from dataclasses import dataclass
 from typing import cast

 import torch
@@ -9,7 +8,10 @@
 from .utils import pad_to_multiple_of, split_for_sequence_parallel


-@dataclass
+# Avoid the dataclass decorator here to get rid of extra ops introduced in PyTorch 2.8 and above.
+# The extra ops come from the function _apply_to_tensors in
+# https://github.com/pytorch/pytorch/blob/v2.8.0/torch/distributed/fsdp/_fully_shard/_fsdp_state.py,
+# which calls dataclasses.replace and thereby triggers SequenceContext.__init__.
 class SequenceContext:
     """Keyword arguments for Flash Attention with Compile.

@@ -29,26 +31,25 @@ class SequenceContext:
     cu_seq_lens_k: torch.IntTensor
     max_length_q: torch.Tensor
     max_length_k: torch.Tensor
-    num_padding: int = 0
-    sequence_parallel_mesh: DeviceMesh | None = None
-    block_table: torch.Tensor | None = None
-    device: str | torch.device = "cpu"  # TODO: this is a bit messy; device is passed around everywhere
-    position_ids: torch.LongTensor | None = None
+    num_padding: int
+    sequence_parallel_mesh: DeviceMesh | None
+    block_table: torch.Tensor | None
+    device: str | torch.device  # TODO: this is a bit messy; device is passed around everywhere
+    position_ids: torch.LongTensor | None

     # Intern-S1
-    image_flags: torch.LongTensor | None = None
+    image_flags: torch.LongTensor | None
     # Qwen3VL
-    image_grid_thw: torch.Tensor | None = None
-    deepstack_visual_embeds: list[torch.Tensor] | None = None
-    visual_pos_masks: torch.Tensor | None = None
-
+    image_grid_thw: torch.Tensor | None
+    deepstack_visual_embeds: list[torch.Tensor] | None
+    visual_pos_masks: torch.Tensor | None
     # mllm model
-    pixel_values: torch.FloatTensor | None = None
-    inputs_embeds: torch.FloatTensor | None = None
-    num_img_tokens: list[int] | None = None
+    pixel_values: torch.FloatTensor | None
+    inputs_embeds: torch.FloatTensor | None
+    num_img_tokens: list[int] | None

     # moe routed_experts
-    rollout_routed_experts: torch.LongTensor | None = None
+    rollout_routed_experts: torch.LongTensor | None

     def __init__(
         self,
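
Note: the comment added above refers to FSDP2's _apply_to_tensors, which rebuilds dataclass instances via dataclasses.replace, and replace() always re-runs the class's __init__. A minimal sketch of that behaviour (toy class, not xtuner code), which is why dropping @dataclass from SequenceContext avoids the extra per-step work:

from dataclasses import dataclass, replace


@dataclass
class Ctx:
    x: int

    def __post_init__(self):
        # stands in for any non-trivial work done during construction
        print("init ran")


c = Ctx(1)            # prints "init ran"
c2 = replace(c, x=2)  # prints "init ran" again: replace() calls __init__


# A plain class with a hand-written __init__ is not a dataclass, so the
# replace()-based code path in FSDP no longer applies to it.
class PlainCtx:
    def __init__(self, x: int):
        self.x = x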

xtuner/v1/engine/train_engine.py

Lines changed: 17 additions & 3 deletions

@@ -135,6 +135,15 @@ def __init__(
         self.optimizer = self.build_optimizer(optim_cfg)
         self.intra_layer_micro_batch = intra_layer_micro_batch
         self._count = 0
+        self.has_freeze_params = self.__has_freeze_params()
+
+    def __has_freeze_params(self) -> bool:
+        has_freeze_params = False
+        for param in self.model.parameters(recurse=True):
+            if not param.requires_grad:
+                has_freeze_params = True
+                break
+        return has_freeze_params

     def build_model(self) -> BaseModel:
         with torch.device("meta"):
@@ -398,7 +407,7 @@ def save_dcp(
         if optimizer_dir is not None:
             optimizer_dir.mkdir(parents=True, exist_ok=True)

-        _options = StateDictOptions(cpu_offload=True, ignore_frozen_params=True)
+        _options = StateDictOptions(cpu_offload=True, ignore_frozen_params=self.model_cfg.dcp_ignore_frozen_params)
         with profile_time_and_memory(f"[DCP Checkpoint to {model_dir}]"):
             model_state = get_model_state_dict(self.model, options=_options)
             dcp.save(
@@ -426,8 +435,13 @@
         Args:
             dcp_dir (str): The directory to load the model from.
         """
-        _load_options = StateDictOptions(cpu_offload=True, ignore_frozen_params=True)
-        _set_options = StateDictOptions(cpu_offload=True, strict=True)
+        _load_options = StateDictOptions(
+            cpu_offload=True, ignore_frozen_params=self.model_cfg.dcp_ignore_frozen_params
+        )
+        if self.has_freeze_params:
+            _set_options = StateDictOptions(cpu_offload=True, strict=False)
+        else:
+            _set_options = StateDictOptions(cpu_offload=True, strict=True)
         with profile_time_and_memory(f"[Load DCP Model from {model_dir}]"):
             shard_model_state_dict = get_model_state_dict(self.model, options=_load_options)
             # inplace state_dict
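
Note: the two hunks above interact. When dcp_ignore_frozen_params is set at save time, frozen (requires_grad=False) tensors are not written to the DCP checkpoint, so loading back with strict=True would fail on the missing keys; detecting frozen parameters and switching to strict=False avoids that. A hedged sketch of the pattern using the public torch.distributed.checkpoint state_dict API (model and ckpt_dir are placeholders, not xtuner names):

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    StateDictOptions,
    get_model_state_dict,
    set_model_state_dict,
)


def save_model(model: torch.nn.Module, ckpt_dir: str, ignore_frozen: bool) -> None:
    # With ignore_frozen_params=True the saved state dict is a subset of the model's keys.
    options = StateDictOptions(cpu_offload=True, ignore_frozen_params=ignore_frozen)
    state = get_model_state_dict(model, options=options)
    dcp.save(state, checkpoint_id=ckpt_dir)


def load_model(model: torch.nn.Module, ckpt_dir: str, ignore_frozen: bool) -> None:
    has_frozen = any(not p.requires_grad for p in model.parameters())
    load_options = StateDictOptions(cpu_offload=True, ignore_frozen_params=ignore_frozen)
    state = get_model_state_dict(model, options=load_options)
    dcp.load(state, checkpoint_id=ckpt_dir)
    # strict=True would raise on keys that were excluded at save time,
    # so relax it whenever the model has frozen parameters.
    set_options = StateDictOptions(cpu_offload=True, strict=not has_frozen)
    set_model_state_dict(model, state, options=set_options)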

xtuner/v1/engine/vision_compose_train_engine.py

Lines changed: 5 additions & 0 deletions

@@ -59,6 +59,11 @@ class VisionComposeConfigProtocol(Protocol):
     projector_config: BaseModel
     text_config: TransformerConfig

+    freeze_vision: bool = False
+    freeze_projector: bool = False
+    freeze_language: bool = False
+    dcp_ignore_frozen_params: bool = True
+
     def build(self) -> VisionComposeModelProtocol: ...

     @property

xtuner/v1/model/base.py

Lines changed: 11 additions & 3 deletions

@@ -74,6 +74,7 @@ class TransformerConfig(PydanticBaseModel):
     max_window_layers: Annotated[int | None, Parameter(group="model")] = None
     rope_scaling_cfg: RopeScalingConfig | None = None
     hf_save_worker: Annotated[int, Parameter(group="model")] = 16
+    dcp_ignore_frozen_params: Annotated[bool, Parameter(group="model")] = False

     @computed_field
     def num_attention_heads(self) -> int:
@@ -520,7 +521,7 @@ def _get_hf_params(
             _hf_key_list = all_hf_keys[start:end]

             if not _hf_key_list:
-                return [], []
+                continue

             hf_keys_list.append(_hf_key_list)

@@ -552,14 +553,21 @@ def _get_hf_params(
         hf_tensor_list: list[torch.Tensor] = []
         # used in self._to_float8 to determine whether to convert a unshard hf_tensor to fp8
         fsdp_shard_tensor_list: list[torch.Tensor] = []
-        for saved_tensor, load_spec, hf_keys in zip(saved_fused_tensor_list, spec_list, hf_keys_list):
+        # `origin_tensor_list` is only used to mark which tensors are float8 weights for the
+        # `_to_float8` function
+        origin_tensor_list: list[torch.Tensor] = []
+
+        for saved_tensor, load_spec, hf_keys, origin_tensor in zip(
+            saved_fused_tensor_list, spec_list, hf_keys_list, tensor_list
+        ):
             dim = cast(int, load_spec.dim)
             hf_tensor_size = saved_tensor.shape[dim] / len(hf_keys)
             assert hf_tensor_size.is_integer(), "Internal Error, hf_tensor_size is not integer"
             hf_tensor_size = int(hf_tensor_size)
             hf_tensor = saved_tensor.split([hf_tensor_size] * len(hf_keys), dim=dim)
             hf_tensor_list.extend(hf_tensor)
             fsdp_shard_tensor_list.extend([saved_tensor] * len(hf_tensor))
+            origin_tensor_list.extend([origin_tensor] * len(hf_tensor))

         name_list = list(chain.from_iterable(hf_keys_list))
         hf_tensor_list = [
@@ -568,7 +576,7 @@ def _get_hf_params(

         if dtype == torch.float8_e4m3fn:
             hf_tensor_list_new, name_list_new = self._to_float8(
-                hf_tensor_list, name_list, fsdp_shard_tensor_list, dtype
+                hf_tensor_list, name_list, origin_tensor_list, dtype
             )
             return hf_tensor_list_new, name_list_new
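
Note: the loop rewritten above fans one fused tensor out into one slice per HF key and extends its bookkeeping lists in lockstep, so a later pass (here the fp8 conversion in _to_float8) can map each slice back to its source tensor. A standalone sketch of that split pattern with made-up shapes and key names:

import torch

saved_tensor = torch.randn(6, 4)   # fused weight covering 3 HF keys
hf_keys = ["w.q", "w.k", "w.v"]    # hypothetical key names
dim = 0

hf_tensor_size = saved_tensor.shape[dim] // len(hf_keys)  # 2 rows per key
slices = saved_tensor.split([hf_tensor_size] * len(hf_keys), dim=dim)

hf_tensor_list = list(slices)
# Each slice remembers which original tensor it came from.
origin_tensor_list = [saved_tensor] * len(slices)
assert len(hf_tensor_list) == len(origin_tensor_list) == len(hf_keys)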

xtuner/v1/model/compose/intern_s1/intern_s1_config.py

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ class InternS1BaseConfig(BaseModel):
     freeze_projector: bool = False
     freeze_language: bool = False
     hf_save_worker: int = 16
+    dcp_ignore_frozen_params: bool = True

     def build(self) -> "InternS1ForConditionalGeneration":
         from .modeling_intern_s1 import InternS1ForConditionalGeneration

xtuner/v1/model/compose/internvl/internvl_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ class InternVLBaseConfig(BaseModel):
9393
freeze_projector: bool = False
9494
freeze_language: bool = False
9595
hf_save_worker: int = 16
96+
dcp_ignore_frozen_params: bool = True
9697

9798
def build(self) -> "InternVLForConditionalGeneration":
9899
from .modeling_internvl import InternVLForConditionalGeneration
