Skip to content

Commit b39581a

Browse files
authored
[bugfix] clip eos_token in judger and fix dataflow retry logic (#1189)
* fix dapo_math judger and fix dataflow logic * fix dapo_math judger and fix dataflow logic * fix typo * rm data * add comments * fix comments * fix * fix * fix * fix ci env * fix
1 parent d381190 commit b39581a

File tree

16 files changed

+243
-61
lines changed

16 files changed

+243
-61
lines changed

ci/scripts/CI_ENV.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/bin/bash
22
export QWEN3_VL_MOE_PATH=${CI_SHARE_MODEL}/Qwen3-VL-30B-A3B-Instruct_MOE
3-
export QWEN3_VL_DENSE_PATH=${CI_SHARE_MODEL}/Qwen3-VL-8B-Instruct_DENSE
3+
export QWEN3_VL_DENSE_PATH=${CI_SHARE_MODEL}/Qwen3-VL-4B-Instruct
44
export INTERN_VL_1B_PATH=${CI_SHARE_MODEL}/InternVL3_5-1B-HF
55
export VIDEO_ROOT=${CI_SHARE_DATA}/images
66
export QWEN3_4B_PATH=${CI_SHARE_MODEL}/Qwen3-4B-Instruct-2507
@@ -16,6 +16,7 @@ export INTERNS1_DENSE_PATH=${CI_SHARE_MODEL}/intern-s1-mini
1616
export ROLLOUT_MODEL_PATH=${CI_SHARE_MODEL}/Qwen3-8B
1717
export ALPACA_PATH=${CI_SHARE_DATA}/alpaca
1818
export INTERNS1_DATA_META=${CI_SHARE_DATA}/vlm_ci_data.json
19+
export ROLLOUT_DAPO_DATA_PATH=${CI_SHARE_DATA}/rl_test_judger_dapo_math_data.jsonl
1920
export TORCH_ALLOW_TF32_CUBLAS_OVERRIDE=0
2021
export XTUNER_DETERMINISTIC=true
2122
export XTUNER_USE_LMDEPLOY=1

examples/v1/config/rl_qwen25_7B_dapo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,12 @@
8686
dataloader_config = DataloaderConfig(pack_max_length=pack_max_length, collator="fake_collator", pack_level="none")
8787

8888
# 3. judger
89+
from xtuner.v1.utils.rl_test_utils import get_eos_token
90+
eos_token_id = get_eos_token(model_path)
91+
eos_token_str = tokenizer.convert_ids_to_tokens(eos_token_id)
8992
dapomath_judger_config = DapoMathJudgerConfig(
90-
judger_name = "dapo_math",
93+
judger_name="dapo_math",
94+
eos_token=eos_token_str,
9195
enable_overlong_buffer = True,
9296
max_response_len =max_response_length,
9397
overlong_buffer_len=4096,

tests/ray/test_evaluator.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import ray
44
from transformers import AutoTokenizer
55

6-
76
from xtuner.v1.ray.config.worker import RolloutConfig
87
from xtuner.v1.ray.judger.controller import JudgerConfig
98
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
@@ -110,6 +109,19 @@ def custom_compute_metric(samples):
110109
custom_correctness = ray.get(custom_evaluator.run.remote())
111110
self.assertEqual(correctness['accuracy'], custom_correctness['custom_accuracy'])
112111

113-
112+
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
113+
def test_lmdeploy_evaluator_with_failed_response(self):
114+
evaluator_cfg = EvaluatorConfig(
115+
dataset_cfg=self.eval_dataset_cfg,
116+
tokenizer=self.tokenizer,
117+
max_concurrent=1,
118+
eval_sample_ratio=1, # generate 5 samples
119+
sample_params=SampleParams(temperature=2.5), # invalid temperature to trigger error
120+
max_retry_times=1,
121+
)
122+
evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
123+
correctness = ray.get(evaluator.run.remote())
124+
self.assertEqual(len(correctness), 0)
125+
114126
if __name__ == '__main__':
115127
unittest.main()

tests/ray/test_judger.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
1818
DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
1919
VERL_ROLLOUT_DATA_PATH = os.environ["VERL_ROLLOUT_DATA_PATH"]
20+
DAPO_DATA_PATH = os.environ.get("ROLLOUT_DAPO_DATA_PATH")
2021

2122
FAKE_JUDGER_INPUT_ITEM = RLDataFlowItem(
2223
uid = RLUIDItem(action_id=uuid4().int,
@@ -67,6 +68,39 @@ def construct_judger_data(data_path):
6768
dataitem.append(data_item)
6869
return dataitem
6970

71+
def construct_dapo_judger_data(data_path):
72+
data_item_list = []
73+
save_reward = []
74+
with open(data_path, 'r', encoding='utf-8') as f:
75+
lines = f.readlines()
76+
for i in range(0, len(lines), 7):
77+
group = ''.join(lines[i:i+7]).strip()
78+
if group:
79+
try:
80+
item = json.loads(group)
81+
data_item = RLDataFlowItem(
82+
uid = RLUIDItem(
83+
action_id=uuid4().int,
84+
observation_id=uuid4().int
85+
),
86+
data = RLDatasetItem(
87+
messages=[{
88+
'role': 'user',
89+
'content': ""
90+
}],
91+
reward_model={"ground_truth": item["label"]},
92+
data_source={"dapo_math": 1.0}
93+
),
94+
env = RLEnvDataItem(
95+
rollout=RLRolloutResponseItem(response=item['response'])
96+
)
97+
)
98+
data_item_list.append(data_item)
99+
save_reward.append(item["reward"])
100+
except Exception as e:
101+
print(f"Error parsing group starting at line {i+12}: {e}")
102+
return data_item_list, save_reward
103+
70104
class TestJudgerController(unittest.TestCase):
71105

72106
@classmethod
@@ -101,6 +135,33 @@ def test_gsm8k_judger(self):
101135
self.assertEqual(res2[0].reward["score"], 1.0)
102136
self.assertEqual(res2[1].reward["score"], 1.0)
103137

138+
def test_dapo_judger(self):
139+
from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
140+
from xtuner.v1.utils.rl_test_utils import get_eos_token_from_model_path
141+
from transformers import AutoTokenizer
142+
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
143+
eos_token_str = get_eos_token_from_model_path(MODEL_PATH, tokenizer)
144+
145+
dapo_judger_config = DapoMathJudgerConfig(
146+
judger_name="dapo_math",
147+
eos_token=eos_token_str,
148+
enable_overlong_buffer=True,
149+
max_response_len=32768,
150+
overlong_buffer_len=4096,
151+
overlong_penalty_factor=1.0,
152+
tokenizer=tokenizer
153+
154+
)
155+
judger_cfg = JudgerConfig(
156+
reward_judger_configs=[dapo_judger_config]
157+
)
158+
judger_controller = JudgerController.remote(judger_cfg)
159+
judger_data, save_reward = construct_dapo_judger_data(DAPO_DATA_PATH)
160+
group_data = ray.get(judger_controller.run.remote(judger_data))
161+
reward = [data.reward["score"] for data in group_data]
162+
avg_score = np.mean(reward)
163+
self.assertLessEqual(float(np.abs(avg_score - np.mean(save_reward))), 0.001)
164+
104165
def test_gsm8k_multi_judger(self):
105166
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
106167
# 支持一个GSM8KJudgerConfig创建多个实例

tests/ray/test_rollout.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,30 @@ def setUp(self):
104104
def tearDown(self):
105105
ray.shutdown()
106106

107+
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
108+
def test_lmdeploy_dataflow_with_failed_response(self):
109+
failed_dataflow_cfg = DataFlowConfig(
110+
env="test",
111+
max_concurrent=1,
112+
prompt_repeat_k=2,
113+
global_batch_size=1,
114+
enable_partial_rollout=0,
115+
max_retry_times=1,
116+
)
117+
self.test_env = SingleTurnEnvironment.remote(
118+
"test_env",
119+
self.pg,
120+
rollout_cfg=self.rollout_cfg,
121+
)
122+
self.test_flow = DataFlow.remote("test_env",
123+
failed_dataflow_cfg,
124+
self.replay_buffer_cfg,
125+
self.test_env
126+
)
127+
sample_params = SampleParams(temperature=2.5) # invalid temperature to trigger error
128+
responses = ray.get(self.test_flow.run.remote(num=1, sample_params=sample_params), timeout=300)
129+
self.assertEqual(len(responses),0)
130+
107131
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
108132
def test_lmdeploy_generate(self):
109133
sample_params = SampleParams(temperature=0.0)

xtuner/v1/ray/base/accelerator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ def __init__(self, **kwargs):
7272
available_memory = available_resources.get("memory", 0)
7373
available_gpus = available_resources.get("GPU", 0)
7474

75-
assert kwargs["num_workers"] <= available_gpus, "Not enough available GPUS in Ray cluster."
75+
assert kwargs["num_workers"] <= available_gpus, (
76+
f"Not enough available GPUS in Ray cluster, available_gpus is {available_gpus} but xtuner needs {kwargs['num_workers']}."
77+
)
7678
# TODO: manage single controller's cpu resource to replace "10" here
7779
assert (kwargs["num_cpus_per_worker"] * kwargs["num_workers"]) + 10 <= available_cpus, (
7880
f"Not enough available CPUs in Ray cluster, available_cpus is {available_cpus} but xtuner needs {kwargs['num_cpus_per_worker'] * kwargs['num_workers'] + 10}."

xtuner/v1/ray/dataflow/flow.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tqdm.auto import tqdm
99
from typing_extensions import Annotated
1010

11-
from xtuner.v1.data_proto.rl_data import RLDataFlowItem
11+
from xtuner.v1.data_proto.rl_data import RLDataFlowItem, check_dataflow_item
1212
from xtuner.v1.ray.environment import SingleTurnEnvironment
1313
from xtuner.v1.ray.rollout.controller import SampleParams
1414
from xtuner.v1.ray.utils import create_task
@@ -141,29 +141,35 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
141141
Optional[List[RLDataFlowItem]]: The group of samples if the task
142142
fails and needs to be retried, otherwise None.
143143
"""
144-
if group_samples_for_retry is not None:
145-
for data_item in group_samples_for_retry:
146-
data_item.extra_info.retry_times += 1
147-
148-
group_data_items = group_samples_for_retry
149144
try:
150-
# 该函数中所有的数据结构都是RLDataFlowItem
151145
# step 1: sample
152-
with timer("sample", self.timer_dict):
153-
group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined]
154-
self.env,
155-
self.config.enable_partial_rollout,
156-
self.config.prompt_repeat_k,
157-
)
158-
self.send_samples_count += 1
159-
self.logger.debug(
160-
f"[ROLLOUT] Get 1 sample and dataflow have sent {self.send_samples_count} to rollout_controller"
161-
)
146+
# TODO(@duanyanhui): More fine-grained control over group data generation:
147+
# Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency
148+
# Resend only the failed prompts in a group when retrying worker_task to avoid wasted computation resources.
149+
if group_samples_for_retry is None or len(group_samples_for_retry) == 0:
150+
with timer("sample", self.timer_dict):
151+
group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined]
152+
self.env,
153+
self.config.enable_partial_rollout,
154+
self.config.prompt_repeat_k,
155+
)
156+
self.send_samples_count += 1
157+
self.logger.debug(
158+
f"[ROLLOUT] Get 1 sample and dataflow have sent {self.send_samples_count} to rollout_controller"
159+
)
160+
else:
161+
group_data_items = group_samples_for_retry
162+
for data_item in group_samples_for_retry:
163+
data_item.extra_info.retry_times += 1
164+
162165
# step 2: env generate
163166
with timer("generate", self.timer_dict):
164167
group_data_items = await self.env_controller.run.remote( # type: ignore[attr-defined]
165168
group_data_items, sample_params=self.sample_params, extra_params=self.extra_params
166169
)
170+
# 需要在这里处理check_dataflow_item,因为要保留group_data_items的data信息,作为retry的输入
171+
if not check_dataflow_item(group_data_items):
172+
return group_data_items
167173

168174
# step 3: filter
169175
with timer("post_process", self.timer_dict):
@@ -175,8 +181,6 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
175181

176182
except Exception as e:
177183
self.logger.error(f"Worker task failed with exception: {e}. Returning meta for retry.", exc_info=True)
178-
for sample in group_data_items: # type: ignore[union-attr]
179-
sample.extra_info.retry_times += 1
180184
return group_data_items
181185

182186
async def concurrent_task_runner(self):
@@ -204,7 +208,10 @@ async def concurrent_task_runner(self):
204208
with tqdm(total=self.target_batch_size, desc="rollout_controller for training samples") as pbar:
205209
update_step = max(1, int(self.target_batch_size * 0.1))
206210
next_update_threshold = update_step
207-
while self.finished_samples_count < self.target_batch_size:
211+
while (
212+
self.finished_samples_count < self.target_batch_size
213+
and self.failed_samples_count < self.target_batch_size
214+
):
208215
if self.finished_samples_count >= next_update_threshold:
209216
pbar.n = self.finished_samples_count
210217
pbar.refresh()
@@ -227,27 +234,36 @@ async def concurrent_task_runner(self):
227234
if result is not None:
228235
if result[0].extra_info.retry_times < self.config.max_retry_times:
229236
# If the retry count is less than max_retry_times, retry the task
237+
self.logger.info(
238+
f"Retrying task for {result[0].data}. Retry count: {result[0].extra_info.retry_times}"
239+
)
230240
retry_task = create_task(self.worker_task(group_samples_for_retry=result))
231241
pending_tasks.add(retry_task)
232242
else:
233-
self.logger.error(f"Max retry reached for {result[0]['prompt_id']}. Not retrying.")
234243
self.failed_samples_count += 1
235-
244+
self.logger.error(
245+
f"Max retry reached for {result[0].data}. Not retrying. Current failed count: {self.failed_samples_count}"
246+
)
236247
self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
237248
waiting_tasks = pending_tasks
238249

239250
pbar.n = self.finished_samples_count
240251
pbar.refresh()
241252

242-
self.logger.info("Target batch size reached. Pausing env controller.")
253+
if self.finished_samples_count == self.target_batch_size:
254+
self.logger.info("Target batch size reached. Pausing env controller.")
255+
if self.failed_samples_count == self.target_batch_size:
256+
self.logger.info("Max failed samples reached. Pausing env controller.")
257+
243258
ray.get(self.env_controller.pause.remote())
244259

245260
if waiting_tasks:
246261
await asyncio.wait_for(asyncio.gather(*waiting_tasks, return_exceptions=True), timeout=10)
247262

248-
self.unfinished_samples_count = ray.get(self.replay_buffer.get_unfinished_samples.remote())
249-
self.logging_replaybuffer_state()
250-
self.logging_timing_perf()
263+
if self.finished_samples_count == self.target_batch_size:
264+
self.unfinished_samples_count = ray.get(self.replay_buffer.get_unfinished_samples.remote())
265+
self.logging_replaybuffer_state()
266+
self.logging_timing_perf()
251267

252268
async def run(
253269
self,

xtuner/v1/ray/dataflow/replay_buffer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def get(self, global_batch_size: int) -> List[List[RLDataFlowItem]]:
318318
"""
319319
samples = []
320320
if len(self._returned) < global_batch_size:
321-
raise ValueError("Not enough finished samples in replay buffer")
321+
self.logger.error("Not enough finished samples in replay buffer")
322322
return []
323323
else:
324324
target_finished_list = self._returned[:global_batch_size]

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from xtuner.v1.data_proto.rl_data import (
88
RLDataFlowItem,
99
RLJudgerResponseItem,
10-
check_dataflow_item,
1110
update_dataflow_item,
1211
)
1312
from xtuner.v1.ray.environment.base_env import BaseEnvironment
@@ -92,6 +91,4 @@ async def run(
9291
if self.judger_controller:
9392
judger_responses: RLJudgerResponseItem = await self.judger_controller.run.remote(group_data_items)
9493
group_data_items = update_dataflow_item(group_data_items, "env.judger", judger_responses)
95-
if not check_dataflow_item(group_data_items):
96-
return []
9794
return group_data_items

0 commit comments

Comments
 (0)