|
import os
from copy import deepcopy

from transformers import AutoTokenizer

from xtuner.v1.config import (
    AdamWConfig,
    FSDPConfig,
    LRConfig,
)
from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.datasets import RLTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.model.moe.qwen3 import Qwen3MoE30BA3Config
from xtuner.v1.ray.base import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.evaluator import EvaluatorConfig
from xtuner.v1.ray.judger.controller import JudgerConfig
from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.train.rl_trainer import RLTrainerConfig
from xtuner.v1.utils.rl_test_utils import get_eos_token
| 23 | + |
| 24 | + |
# Required environment variables — fail fast with KeyError if any is unset.
work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
# Optional flag; kept as a string ("0"/"1") because it is compared against
# the literal "1" where the rollout config is built.
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
# Evaluation is enabled whenever a non-empty eval dataset path was provided.
enable_evaluate = bool(eval_data_path)

# basic settings
experimental_name = "dapo_math"
total_epochs = 1
global_batch_size = 512        # batch size handed to the dataflow per step
prompt_repeat_k = 16           # rollouts generated per prompt (GRPO group size)
rollout_tp_size = 2            # tensor-parallel size of the rollout engine
rollout_ep_size = 1            # expert-parallel size of the rollout engine
max_prompt_length = 2048
max_response_length = 8192
pack_max_length = 32768        # token budget per packed training sequence
train_optimizer_steps = 16     # optimizer_steps passed to the train worker
hf_interval = 50               # NOTE(review): presumably HF-checkpoint interval in steps — confirm
enable_initial_evaluate = True # evaluate once before training starts
evaluate_step = 5              # evaluate every N steps
| 46 | + |
# 1. resources — accelerator pool handed to RLTrainerConfig below.
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)
| 54 | + |
# 2. rollout — inference-engine settings used to generate responses.
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    # The engine context must fit the longest prompt plus the longest response.
    context_length=max_response_length + max_prompt_length,
    # Env var carries the string "1" when the feature is enabled.
    enable_return_routed_experts=enable_return_routed_experts == "1",
)
| 67 | + |
# sampling params
# Training-time sampling at temperature 1.0 with top_p=1.0
# (top_k=0 presumably disables top-k filtering — confirm SampleParams semantics).
training_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=0,
    top_p=1.0,
    temperature=1.0,
    min_tokens=0,
)
# Evaluation reuses the training params but with a tighter nucleus (top_p=0.7).
evaluation_sample_params = deepcopy(training_sample_params)
evaluation_sample_params.top_p = 0.7
| 78 | + |
# dataset
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
# Eval dataset exists only when an eval path was supplied via the environment.
eval_dataset = DatasetConfig(name=experimental_name, anno_path=eval_data_path) if enable_evaluate else None
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# max_length caps the prompt side only; responses are bounded by SampleParams.
tokenizer_config = RLTokenizeFnConfig(max_length=max_prompt_length)

train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}] if enable_evaluate else []

# NOTE(review): "fake_collator" + pack_level="none" suggests batching/packing
# happens downstream in the RL worker rather than in this dataloader — confirm.
dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none")
| 89 | + |
# 3. judger — scores rollout responses; needs the model's EOS token so it can
# tell finished responses from truncated ones.
eos_token_id = get_eos_token(model_path)
eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
dapomath_judger_config = DapoMathJudgerConfig(
    judger_name="dapo_math",
    eos_token=eos_token_str,
    # NOTE(review): presumably DAPO's overlong-buffer reward shaping — responses
    # entering the last `overlong_buffer_len` tokens of the budget get penalized.
    enable_overlong_buffer=True,
    max_response_len=max_response_length,
    overlong_buffer_len=4096,
    overlong_penalty_factor=1.0,
    tokenizer=tokenizer,
)
judger_cfg = JudgerConfig(reward_judger_configs=[dapomath_judger_config])
| 103 | + |
# 4. dataflow and evaluator
dataflow_config = DataFlowConfig(
    env=experimental_name,
    # Each prompt is rolled out prompt_repeat_k times (the GRPO group).
    prompt_repeat_k=prompt_repeat_k,
    global_batch_size=global_batch_size,
    sample_params=training_sample_params,
)
| 111 | + |
| 112 | + |
def dapo_compute_metric(samples):
    """Compute evaluation accuracy over judged rollout samples.

    A sample counts as correct when its judger "acc" reward is positive.

    Args:
        samples: Iterable of samples, each exposing
            ``sample.env.judger.reward["acc"]`` (a number).

    Returns:
        dict: ``{"accuracy": fraction_of_correct_samples}`` — ``0.0`` for an
        empty input instead of raising ``ZeroDivisionError``.
    """
    samples = list(samples)
    if not samples:
        return {"accuracy": 0.0}
    n_correct = sum(s.env.judger.reward["acc"] > 0 for s in samples)
    return {"accuracy": n_correct / len(samples)}
| 115 | + |
| 116 | + |
# The evaluator is only built when an eval dataset was supplied; otherwise the
# trainer receives None and skips evaluation entirely.
if enable_evaluate:
    evaluator_cfg = EvaluatorConfig(
        enable_evaluate=enable_evaluate,
        enable_initial_evaluate=enable_initial_evaluate,
        dataset_cfg=eval_dataset_cfg,
        tokenizer=tokenizer,
        evaluate_step=evaluate_step,
        compute_metric_func=dapo_compute_metric,
        sample_params=evaluation_sample_params,
    )
else:
    evaluator_cfg = None

# Replay buffer feeding rollout samples back into training.
replay_buffer_cfg = ReplayBufferConfig(
    tokenizer=tokenizer,
    dataset_cfg=train_dataset_cfg,
    dataloader_cfg=dataloader_config,
)
| 130 | + |
# 5. Train worker
model_cfg = Qwen3MoE30BA3Config()  # Qwen3 MoE 30B (A3) architecture config
optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        # NOTE(review): asymmetric clip ranges (upper 0.28 > lower 0.2) match
        # DAPO's "clip-higher" scheme — confirm against the loss implementation.
        cliprange_high=0.28,
        cliprange_low=0.2,
        loss_type="vanilla",
        clip_ratio_c=10.0,
        # Clamp on log-prob differences, presumably for numerical stability.
        log_prob_diff_min=-20.0,
        log_prob_diff_max=20.0,
    ),
    ignore_idx=-100,  # label index excluded from the loss
    # KL regularization disabled; coef/type below are inert while use_kl_loss=False.
    use_kl_loss=False,
    kl_loss_coef=0.0,
    kl_loss_type="low_var_kl",
    mode="chunk",
    chunk_size=512,
)
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
train_worker_cfg: WorkerConfig = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=1,  # sequence-parallel size
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)
| 163 | + |
# 6. RL Trainer — top-level config wiring all components together.
trainer = RLTrainerConfig(
    load_from=model_path,
    resources=resources,
    rollout_config=rollout_config,
    dataflow_config=dataflow_config,
    judger_config=judger_cfg,
    replay_buffer_config=replay_buffer_cfg,
    evaluator_config=evaluator_cfg,  # None when evaluation is disabled
    train_worker_config=train_worker_cfg,
    tokenizer_path=model_path,
    work_dir=work_dir,
    total_epochs=total_epochs,
    hf_interval=hf_interval,
)