
Commit ecb3d66

[Refactor] refactor http request error (#1259)
* [Refactor] refactor http request error
* add ut
* fix ut
* fix
* fix ut
* fix
* fix
* fix ut
* fix comments
* fix comments
* fix comments
* fix comments
* fix ci
* fix
* fix
* fix deactivate rollout worker
* rm useless code
* fix ci
1 parent 1911c71 commit ecb3d66

20 files changed: +591 / -180 lines


tests/ray/test_evaluator.py

Lines changed: 0 additions & 14 deletions
@@ -107,20 +107,6 @@ def custom_compute_metric(samples):
         custom_evaluator = Evaluator.remote(custom_evaluator_cfg, self.test_env)
         custom_correctness = ray.get(custom_evaluator.run.remote())
         self.assertEqual(correctness['accuracy'], custom_correctness['custom_accuracy'])
-
-    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
-    def test_lmdeploy_evaluator_with_failed_response(self):
-        evaluator_cfg = EvaluatorConfig(
-            dataset_cfg=self.eval_dataset_cfg,
-            tokenizer=self.tokenizer,
-            max_concurrent=1,
-            eval_sample_ratio=1,  # generate 5 samples
-            sample_params=SampleParams(temperature=2.5),  # invalid temperature to trigger error
-            max_retry_times=1,
-        )
-        evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
-        correctness = ray.get(evaluator.run.remote())
-        self.assertEqual(len(correctness), 0)
 
 if __name__ == '__main__':
     unittest.main()

tests/ray/test_mock_rollout.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
+import os
+import unittest
+import ray
+from transformers import AutoTokenizer
+import torch
+import httpx
+from xtuner.v1.ray.config.worker import RolloutConfig
+from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
+from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
+from xtuner.v1.ray.environment import SingleTurnEnvironment
+from xtuner.v1.datasets import RLTokenizeFnConfig
+from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
+from xtuner.v1.ray.rollout.controller import RolloutController
+from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker
+
+MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
+TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
+resource_map = {"npu": "NPU", "cuda": "GPU"}
+@ray.remote
+class MockTimeoutRolloutController(RolloutController):
+    def _get_worker_cls(self):
+        return ray.remote(MockTimeoutRolloutWorker)
+
+@ray.remote
+class MockRequestErrorRolloutController(RolloutController):
+    def _get_worker_cls(self):
+        return ray.remote(MockRequestErrorRolloutWorker)
+
+@ray.remote
+class MockClientErrorRolloutController(RolloutController):
+    def _get_worker_cls(self):
+        return ray.remote(MockClientErrorRolloutWorker)
+
+@ray.remote
+class MockServerErrorRolloutController(RolloutController):
+    def _get_worker_cls(self):
+        return ray.remote(MockServerErrorRolloutWorker)
+
+    def deactivate_worker_by_url(self, url):
+        pass
+
+class TestMockRollout(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["XTUNER_USE_FA3"] = "1"
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["XTUNER_USE_FA3"]
+        ray.shutdown()
+
+    def setUp(self):
+        ray.init(num_cpus=80, ignore_reinit_error=True)
+        self.global_batch_size = 3
+        self.max_prompt_length = 4096
+        self.max_response_length = 128
+        self.max_concurrent = 3
+        self.max_retry_times = 3
+
+        self.resources_cfg = AcceleratorResourcesConfig(
+            accelerator=resource_map[torch.accelerator.current_accelerator().type],
+            num_workers=8,
+            num_cpus_per_worker=8,
+            cpu_memory_per_worker=16 * 1024**3,  # 16 GB
+        )
+        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)
+
+        self.rollout_cfg = RolloutConfig(
+            env="test_mock_rollout",
+            model_path=MODEL_PATH,
+            model_name=os.path.basename(MODEL_PATH).lower(),
+            tokenizer_path=MODEL_PATH,
+            tensor_parallel_size=1,
+            context_length=self.max_prompt_length + self.max_response_length,
+            max_retry_per_worker=2
+        )
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+
+        self.dataflow_cfg = DataFlowConfig(
+            max_concurrent=self.max_concurrent,
+            global_batch_size=self.global_batch_size,
+            max_retry_times=self.max_retry_times
+        )
+        train_dataset_cfg = [{
+            "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH),
+            "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length),
+        }]
+        dataloader_cfg = DataloaderConfig(
+            collator='fake_collator',
+            pack_level='none',
+            group_by_length=False,
+        )
+        self.replay_buffer_cfg = ReplayBufferConfig(
+            dataset_cfg=train_dataset_cfg,
+            dataloader_cfg=dataloader_cfg,
+            tokenizer=tokenizer,
+        )
+
+    def tearDown(self):
+        ray.shutdown()
+
+    def _run_mock_test(self, mock_controller_cls, error_name: str):
+        rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
+        self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller)
+        self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env)
+
+        completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))
+
+        status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
+        print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")
+        self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
+        self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
+        self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_rollout_with_timeout_mock(self):
+        self._run_mock_test(MockTimeoutRolloutController, "timeout")
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_rollout_with_request_error_mock(self):
+        self._run_mock_test(MockRequestErrorRolloutController, "request error")
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_rollout_with_client_error_mock(self):
+        self._run_mock_test(MockClientErrorRolloutController, "client error")
+
+    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
+    def test_rollout_with_server_error_mock(self):
+        self._run_mock_test(MockServerErrorRolloutController, "server error")
+
+if __name__ == "__main__":
+    unittest.main()
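
The four mock controllers above only override _get_worker_cls, so the rest of the dataflow stack runs unmodified while every HTTP request fails in a controlled way. The mock worker classes themselves come from xtuner.v1.utils.rl_test_utils and are not part of this diff; a minimal sketch of what such a worker could look like (class name and method are assumptions, shown only to illustrate the pattern and why httpx is imported):

    import httpx

    # Hypothetical sketch: the real MockTimeoutRolloutWorker in
    # xtuner.v1.utils.rl_test_utils may differ in base class and signature.
    class SketchTimeoutRolloutWorker:
        async def _send_request(self, url: str, payload: dict):
            # Simulate a request that never receives a response.
            raise httpx.TimeoutException(f"mocked timeout for {url}")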

tests/ray/test_rollout.py

Lines changed: 4 additions & 26 deletions
@@ -114,33 +114,11 @@ def setUp(self):
     def tearDown(self):
         ray.shutdown()
 
-    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
-    def test_lmdeploy_dataflow_with_failed_request(self):
-        failed_dataflow_cfg = DataFlowConfig(
-            env="test",
-            max_concurrent=1,
-            prompt_repeat_k=2,
-            global_batch_size=1,
-            enable_partial_rollout=0,
-        )
-        self.test_env = SingleTurnEnvironment.remote(
-            "test_env",
-            self.pg,
-            rollout_cfg=self.rollout_cfg,
-        )
-        self.test_flow = DataFlow.remote("test_env",
-                                         failed_dataflow_cfg,
-                                         self.replay_buffer_cfg,
-                                         self.test_env
-                                         )
-        sample_params = SampleParams(temperature=2.5)  # invalid temperature to trigger error
-        with self.assertRaises(AssertionError):
-            ray.get(self.test_flow.run.remote(num=1, sample_params=sample_params), timeout=300)
-
+
     @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
     def test_lmdeploy_generate(self):
         sample_params = SampleParams(temperature=0.0)
-        rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg)  # type: ignore[attr-defined]
+        rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg)  # type: ignore[attr-defined]
         res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
 
         self.assertEqual(res1.finish_reason, "stop")
@@ -208,7 +186,7 @@ def test_lmdeploy_turbomind_generate(self):
         from xtuner.v1.ray.rollout import LMDeployWorker
         self.rollout_cfg.extra_rollout_config["lmdeploy_backend"] = "turbomind"
         sample_params = SampleParams(temperature=0.0)
-        rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg)  # type: ignore[attr-defined]
+        rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg)  # type: ignore[attr-defined]
         res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
         res2 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
         self.assertEqual(res1, res2, f"res1 != res2, res1={res1}, res2={res2}")
@@ -219,7 +197,7 @@ def test_sglang_generate(self):
         from xtuner.v1.ray.rollout import SGLangWorker
         self.rollout_cfg.launch_server_method="multiprocessing"
         sample_params = SampleParams(temperature=0.0)
-        rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg)  # type: ignore[attr-defined]
+        rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg)  # type: ignore[attr-defined]
         res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
         self.assertEqual(res1.finish_reason, "stop")
         print("Response from SGLang infer:", res1)

tests/ray/test_update_weight.py

Lines changed: 2 additions & 2 deletions
@@ -109,7 +109,7 @@ def test_lmdeploy_update_weight_and_generate(self):
         sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)
 
         # init rollout_update
-        rollout_controller = RolloutController.remote(
+        rollout_controller = ray.remote(RolloutController).remote(
             self.rollout_cfg,
             self.pg,
         )
@@ -129,7 +129,7 @@ def test_lmdeploy_update_weight_and_generate(self):
 
         # init rollout_ref
         self.rollout_cfg.skip_load_weights = False
-        rollout_controller_ref = RolloutController.remote(
+        rollout_controller_ref = ray.remote(RolloutController).remote(
             self.rollout_cfg,
             self.pg,
         )

xtuner/v1/data_proto/rl_data.py

Lines changed: 3 additions & 2 deletions
@@ -69,7 +69,7 @@ class RLRolloutResponseItem(BaseModel):
     response: Optional[str] = None
     response_ids: Optional[List[int]] = None
     num_return_tokens: Optional[int] = None
-    finish_reason: Optional[str] = None
+    finish_reason: Optional[str] = None  # "stop", "length", "abort", "failed", "skipped"
     logprobs: Optional[List[float]] = None
     extra_info: Dict[str, Any] = dict()
 
@@ -153,7 +153,8 @@ def check_dataflow_item(group_data_items):
 
     # If any item is in the abort state, skip the check; the sample will be rolled out again next time
     is_abort = any(item.env.rollout.finish_reason == "abort" for item in group_data_items)
-    if is_abort:
+    is_skipped = any(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
+    if is_abort or is_skipped:
         return True
 
     no_failures = all(item.env.rollout.finish_reason != "failed" for item in group_data_items)
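
finish_reason now distinguishes a permanently bad request ("skipped") from a transient failure ("failed") and an interrupted rollout ("abort"). The sketch below re-implements only the branch visible in this hunk to show its behavior; the real check_dataflow_item performs additional checks after the no_failures line:

    # Standalone sketch of the branch shown above (not the module itself).
    def check_group(finish_reasons):
        # "abort" and "skipped" short-circuit the check, so the group is not
        # treated as a hard failure by the caller.
        if any(r in ("abort", "skipped") for r in finish_reasons):
            return True
        # Any "failed" item fails the check; the caller then retries the group.
        return all(r != "failed" for r in finish_reasons)

    assert check_group(["stop", "skipped"]) is True
    assert check_group(["stop", "failed"]) is False
    assert check_group(["stop", "stop"]) is True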

xtuner/v1/ray/base/accelerator.py

Lines changed: 3 additions & 1 deletion
@@ -404,7 +404,9 @@ def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup):
         rank_bundle_idx_list = []
         for rank, bundle_idx in enumerate(sorted_bundle_idxs):
             worker = worker_cls.options(
-                placement_group=pg, placement_group_bundle_index=bundle_idx, **pg_options
+                placement_group=pg,
+                placement_group_bundle_index=bundle_idx,
+                **pg_options,
             ).remote(worker_config, rank, master_addr, master_port, world_size, device_type)
             workers_list.append(worker)
             rank_bundle_idx_list.append((rank, bundle_idx))

xtuner/v1/ray/config/worker.py

Lines changed: 10 additions & 0 deletions
@@ -198,6 +198,13 @@ class RolloutConfig(BaseModel):
             help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.',
         ),
     ] = {}
+    max_retry_per_worker: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Maximum number of retries per rollout worker before deactivation.",
+        ),
+    ] = None
     worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
 
     def model_post_init(self, __context: Any) -> None:
@@ -259,6 +266,9 @@ def model_post_init(self, __context: Any) -> None:
         else:
             self.rollout_max_batch_size_per_instance = 128
 
+        if self.max_retry_per_worker is None:
+            self.max_retry_per_worker = self.rollout_max_batch_size_per_instance
+
         self.worker_log_dir.mkdir(parents=True, exist_ok=True)
 
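
max_retry_per_worker is optional: the new test passes it explicitly (max_retry_per_worker=2 in test_mock_rollout.py above), and when it is left unset model_post_init falls back to rollout_max_batch_size_per_instance. A standalone sketch of that fallback (illustrative values, not the Pydantic model itself):

    # Sketch of the fallback added in model_post_init.
    max_retry_per_worker = None                    # value taken from the user config
    rollout_max_batch_size_per_instance = 128      # resolved earlier in model_post_init
    if max_retry_per_worker is None:
        max_retry_per_worker = rollout_max_batch_size_per_instance
    assert max_retry_per_worker == 128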

xtuner/v1/ray/dataflow/flow.py

Lines changed: 16 additions & 4 deletions
@@ -115,6 +115,7 @@ def __init__(
         self.finished_samples_count = 0
         self.unfinished_samples_count = 0
         self.failed_samples_count = 0
+        self.skipped_sample_count = 0
         self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
         self.target_batch_size = self.config.global_batch_size
         self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
@@ -192,6 +193,13 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
                 f"Dataflow item check failed for {group_data_items[0].uid.action_id} response. Returning meta for retry."
             )
             return group_data_items
+        if any(item.env.rollout.finish_reason == "skipped" for item in group_data_items):
+            self.logger.warning(
+                f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request."
+            )
+            self.logger.debug(f"Worker task skipped successfully for {group_data_items[0].uid.action_id}.")
+            self.skipped_sample_count += 1
+            return
 
         # step 3: filter
         filtered_group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
@@ -228,7 +236,8 @@ async def concurrent_task_runner(self):
         next_update_threshold = update_step
         while (
             self.finished_samples_count < self.target_batch_size
-            and self.failed_samples_count < self.target_batch_size
+            and self.failed_samples_count < self.target_batch_size * self.config.max_retry_times
+            and self.skipped_sample_count < self.target_batch_size * self.config.max_retry_times
         ):
             if self.finished_samples_count >= next_update_threshold:
                 pbar.n = self.finished_samples_count
@@ -253,14 +262,14 @@ async def concurrent_task_runner(self):
                 if result[0].extra_info.retry_times < self.config.max_retry_times:
                     # If the retry count is less than max_retry_times, retry the task
                     self.logger.info(
-                        f"Retrying task for {result[0].data}. Retry count: {result[0].extra_info.retry_times}"
+                        f"Retrying task for {result[0].uid.action_id}. Retry count: {result[0].extra_info.retry_times}"
                     )
                     retry_task = create_task(self.worker_task(group_samples_for_retry=result))
                     pending_tasks.add(retry_task)
                 else:
                     self.failed_samples_count += 1
                     self.logger.error(
-                        f"Max retry reached for {result[0].data}. Not retrying. Current failed count: {self.failed_samples_count}"
+                        f"Max retry reached for {result[0].uid.action_id}. Not retrying. Current failed count: {self.failed_samples_count}"
                     )
             self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
             waiting_tasks = pending_tasks
@@ -270,8 +279,10 @@ async def concurrent_task_runner(self):
 
         if self.finished_samples_count >= self.target_batch_size:
             self.logger.info("Target batch size reached. Pausing env controller.")
-        if self.failed_samples_count >= self.target_batch_size:
+        if self.failed_samples_count >= self.target_batch_size * self.config.max_retry_times:
             self.logger.info("Max failed samples reached. Pausing env controller.")
+        if self.skipped_sample_count >= self.target_batch_size * self.config.max_retry_times:
+            self.logger.info("Max skipped samples reached. Pausing env controller.")
 
         # NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()`
         # would be queued behind many worker tasks, causing a significant delay.
@@ -344,6 +355,7 @@ async def run(
         self.finished_samples_count = 0
         self.unfinished_samples_count = 0
         self.failed_samples_count = 0
+        self.skipped_sample_count = 0
         self.target_batch_size = num if num and num > 0 else self.config.global_batch_size
         self.logger.info(f"Start generate dataflow and target batch size set to {self.target_batch_size}.")
         self.sample_params = sample_params if sample_params else self.config.sample_params
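
With the new skipped_sample_count, the dataflow loop stops on any of three conditions: the target number of finished samples, a failure budget, or a skip budget, where both budgets are target_batch_size * max_retry_times rather than target_batch_size alone. A standalone sketch of the loop guard (illustrative values matching the mock-rollout test config above):

    # Sketch of the while-condition in concurrent_task_runner.
    target_batch_size = 3   # e.g. DataFlowConfig.global_batch_size in the mock test
    max_retry_times = 3     # e.g. DataFlowConfig.max_retry_times in the mock test

    def should_continue(finished: int, failed: int, skipped: int) -> bool:
        return (
            finished < target_batch_size
            and failed < target_batch_size * max_retry_times
            and skipped < target_batch_size * max_retry_times
        )

    assert should_continue(finished=0, failed=8, skipped=0) is True
    assert should_continue(finished=0, failed=9, skipped=0) is False  # failure budget exhausted
    assert should_continue(finished=0, failed=0, skipped=9) is False  # skip budget exhausted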
