Skip to content

Commit 7302ba5

Browse files
Support MoE Router Training (#1207)
Support MoE router training: add an `enable_return_routed_experts` config option, training support, Ray-based data transfer (#5), and a `context_length` setting; includes assorted fixes and lint cleanup. Co-authored-by: RunningLeon <[email protected]>
1 parent 531737f commit 7302ba5

File tree

14 files changed

+546
-44
lines changed

14 files changed

+546
-44
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
"""GRPO/DAPO math RL training config for Qwen3-MoE-30B-A3 (xtuner v1).

Reads WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH and (optionally)
ENABLE_RETURN_ROUTED_EXPERTS from the environment and builds the
module-level ``trainer`` :class:`RLTrainerConfig` consumed by the launcher.
"""
import os
from copy import deepcopy

from transformers import AutoTokenizer

from xtuner.v1.config import (
    AdamWConfig,
    FSDPConfig,
    LRConfig,
)
from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.datasets import RLTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
from xtuner.v1.ray.base import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.evaluator import EvaluatorConfig
from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.train.rl_trainer import RLTrainerConfig
from xtuner.v1.utils.rl_test_utils import get_eos_token  # hoisted from mid-file


work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
# Evaluation is optional: an unset or empty EVAL_DATA_PATH disables it.
eval_data_path = os.environ.get("EVAL_DATA_PATH", "")
# "1" enables returning routed-expert ids from rollout (MoE router training).
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
enable_evaluate = eval_data_path != ""

# basic settings
experimental_name = "dapo_math"
total_epochs = 1
global_batch_size = 512
prompt_repeat_k = 16  # rollouts generated per prompt (GRPO group size)
rollout_tp_size = 2
rollout_ep_size = 1
max_prompt_length = 2048
max_response_length = 8192
pack_max_length = 32768
train_optimizer_steps = 16
hf_interval = 50  # save an HF-format checkpoint every N steps
enable_initial_evaluate = True
evaluate_step = 5

# 1. resources
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)

# 2. rollout
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    # Rollout context must cover the full prompt plus the longest response.
    context_length=max_response_length + max_prompt_length,
    enable_return_routed_experts=enable_return_routed_experts == "1",
)

# sampling params
training_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=0,
    top_p=1.0,
    temperature=1.0,
    min_tokens=0,
)
evaluation_sample_params = deepcopy(training_sample_params)
evaluation_sample_params.top_p = 0.7

# dataset
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length)

train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else []

dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none")

# 3. judger
eos_token_id = get_eos_token(model_path)
eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
dapomath_judger_config = DapoMathJudgerConfig(
    judger_name="dapo_math",
    eos_token=eos_token_str,
    # DAPO overlong-buffer reward shaping: penalize responses that run into
    # the last `overlong_buffer_len` tokens before `max_response_len`.
    enable_overlong_buffer=True,
    max_response_len=max_response_length,
    overlong_buffer_len=4096,
    overlong_penalty_factor=1.0,
    tokenizer=tokenizer,
)
judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config])

# 4. dataflow and evaluator
dataflow_config = DataFlowConfig(
    env=experimental_name,
    prompt_repeat_k=prompt_repeat_k,
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
)


def dapo_compute_metric(samples):
    """Return the fraction of samples whose judger accuracy reward is positive."""
    return {"accuracy": sum(s.env.judger.reward["acc"] > 0 for s in samples) / len(samples)}


evaluator_cfg = EvaluatorConfig(
    enable_evaluate=enable_evaluate,
    enable_initial_evaluate=enable_initial_evaluate,
    dataset_cfg=eval_dataset_cfg,
    tokenizer=tokenizer,
    evaluate_step=evaluate_step,
    compute_metric_func=dapo_compute_metric,
    sample_params=evaluation_sample_params,
) if enable_evaluate else None

replay_buffer_cfg = ReplayBufferConfig(
    dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer
)

# 5. Train worker
model_cfg = Qwen3MoE30BA3Config()
optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        # DAPO-style asymmetric clipping (clip-higher).
        cliprange_high=0.28,
        cliprange_low=0.2,
        loss_type="vanilla",
        clip_ratio_c=10.0,
        log_prob_diff_min=-20.0,
        log_prob_diff_max=20.0,
    ),
    ignore_idx=-100,
    use_kl_loss=False,
    kl_loss_coef=0.0,
    kl_loss_type="low_var_kl",
    mode="chunk",
    chunk_size=512,
)
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
train_worker_cfg: WorkerConfig = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=1,
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)

# 6. RL Trainer
trainer = RLTrainerConfig(
    load_from=model_path,
    resources=resources,
    rollout_config=rollout_config,
    dataflow_config=dataflow_config,
    judger_config=judger_cfg,
    replay_buffer_config=replay_buffer_cfg,
    evaluator_config=evaluator_cfg,
    train_worker_config=train_worker_cfg,
    tokenizer_path=model_path,
    work_dir=work_dir,
    total_epochs=total_epochs,
    hf_interval=hf_interval,
)
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
"""GRPO GSM8K RL training config for Qwen3-MoE-30B-A3 (xtuner v1).

Reads WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH and (optionally)
ENABLE_RETURN_ROUTED_EXPERTS from the environment and builds the
module-level ``trainer`` :class:`RLTrainerConfig` consumed by the launcher.
"""
import os
from copy import deepcopy

from transformers import AutoTokenizer

from xtuner.v1.config import (
    AdamWConfig,
    FSDPConfig,
    LRConfig,
)
from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.datasets import RLTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
from xtuner.v1.ray.base import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.evaluator import EvaluatorConfig
from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.train.rl_trainer import RLTrainerConfig


work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
# Evaluation is optional: an unset or empty EVAL_DATA_PATH disables it.
eval_data_path = os.environ.get("EVAL_DATA_PATH", "")
# "1" enables returning routed-expert ids from rollout (MoE router training).
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
enable_evaluate = eval_data_path != ""

# basic settings
experimental_name = "grpo_gsm8k"
total_epochs = 15
global_batch_size = 1024
prompt_repeat_k = 5  # rollouts generated per prompt (GRPO group size)
rollout_tp_size = 2
rollout_ep_size = 1
max_prompt_length = 512
max_response_length = 1024
pack_max_length = 32768
train_optimizer_steps = 4
hf_interval = 15  # save an HF-format checkpoint every N steps
enable_initial_evaluate = True
evaluate_step = 10

# grpo quick test settings
# total_epochs = 3
# global_batch_size = 64
# prompt_repeat_k = 5
# rollout_tp_size = 1
# rollout_ep_size = 1
# max_prompt_length = 512
# max_response_length = 1024
# pack_max_length = 32768
# train_optimizer_steps = 1
# hf_interval = 100
# enable_initial_evaluate = True
# evaluate_step = 15

# 1. resources
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)

# 2. rollout
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.75,
    # Rollout context must cover the full prompt plus the longest response.
    context_length=max_response_length + max_prompt_length,
    enable_return_routed_experts=enable_return_routed_experts == "1",
)

# sampling params
training_sample_params = SampleParams(
    max_tokens=max_response_length,
)
# Evaluation decodes greedily (top_k=1, temperature=0) for a deterministic score.
evaluation_sample_params = deepcopy(training_sample_params)
evaluation_sample_params.top_p = 1.0
evaluation_sample_params.temperature = 0.0
evaluation_sample_params.top_k = 1

# dataset (no changes needed)
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length)

train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else []

dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none")

# 3. judger
gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config])

# 4. dataflow and evaluator
dataflow_config = DataFlowConfig(
    env=experimental_name,
    prompt_repeat_k=prompt_repeat_k,
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
)

evaluator_cfg = EvaluatorConfig(
    enable_evaluate=enable_evaluate,
    enable_initial_evaluate=enable_initial_evaluate,
    dataset_cfg=eval_dataset_cfg,
    tokenizer=tokenizer,
    evaluate_step=evaluate_step,
    compute_metric_func=None,  # use the evaluator's default metric
    sample_params=evaluation_sample_params,
) if enable_evaluate else None

# replay buffer config (no changes needed)
replay_buffer_cfg = ReplayBufferConfig(
    dataset_cfg=train_dataset_cfg, dataloader_cfg=dataloader_config, tokenizer=tokenizer
)

# 5. Train worker
# NOTE: modify model_cfg when training a different model.
model_cfg = Qwen3MoE30BA3Config()
optim_cfg = AdamWConfig(lr=1e-6, foreach=False)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        cliprange_high=0.2,
        cliprange_low=0.2,
        loss_type="vanilla",
    ),
    ignore_idx=-100,
    use_kl_loss=True,
    kl_loss_coef=0.001,
    kl_loss_type="low_var_kl",
    mode="chunk",
    chunk_size=512,
)
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
train_worker_cfg: WorkerConfig = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=1,
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)

# 6. RL Trainer
trainer = RLTrainerConfig(
    load_from=model_path,
    resources=resources,
    rollout_config=rollout_config,
    dataflow_config=dataflow_config,
    judger_config=judger_cfg,
    replay_buffer_config=replay_buffer_cfg,
    evaluator_config=evaluator_cfg,
    train_worker_config=train_worker_cfg,
    tokenizer_path=model_path,
    work_dir=work_dir,
    total_epochs=total_epochs,
    hf_interval=hf_interval,
)

0 commit comments

Comments
 (0)