Skip to content

Commit 1ed910f

Browse files
committed
[Refactor] refactor http request error
1 parent 7302ba5 commit 1ed910f

File tree

7 files changed

+223
-82
lines changed

7 files changed

+223
-82
lines changed

xtuner/v1/data_proto/rl_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ class RLRolloutResponseItem(BaseModel):
6969
response: Optional[str] = None
7070
response_ids: Optional[List[int]] = None
7171
num_return_tokens: Optional[int] = None
72-
finish_reason: Optional[str] = None
72+
finish_reason: Optional[str] = None # "stop", "length", "abort", "failed", "skipped"
7373
logprobs: Optional[List[float]] = None
7474
extra_info: Dict[str, Any] = dict()
7575

xtuner/v1/ray/config/worker.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ class RolloutConfig(BaseModel):
198198
help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.',
199199
),
200200
] = {"lmdeploy_log_level": "CRITICAL", "lmdeploy_uvicorn_log_level": "CRITICAL"}
201+
max_retry_per_worker: Annotated[
202+
Optional[int],
203+
Parameter(
204+
group=infer_group,
205+
help="Maximum number of retries per rollout worker before deactivation.",
206+
),
207+
] = None
201208
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
202209

203210
def __init__(self, **kwargs):
@@ -268,6 +275,9 @@ def __init__(self, **kwargs):
268275
else:
269276
kwargs["rollout_max_batch_size_per_instance"] = 128
270277

278+
if "max_retry_per_worker" not in kwargs:
279+
kwargs["max_retry_per_worker"] = int(kwargs["rollout_max_batch_size_per_instance"] * 0.1)
280+
271281
super().__init__(**kwargs)
272282
self.worker_log_dir.mkdir(parents=True, exist_ok=True)
273283

xtuner/v1/ray/dataflow/flow.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,11 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
193193
f"Dataflow item check failed for {group_data_items[0].uid.action_id} response. Returning meta for retry."
194194
)
195195
return group_data_items
196+
if any(item.env.rollout.finish_reason == "skipped" for item in group_data_items):
197+
self.logger.warning(
198+
f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request."
199+
)
200+
return
196201

197202
# step 3: filter
198203
filtered_group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined]

xtuner/v1/ray/environment/single_turn_env.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,18 +62,18 @@ async def generate(
6262
and state from the rollout controller.
6363
"""
6464
if self.rollout_controller:
65-
# 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦
66-
# 每个模块返回独立的data item, 在env中进行更新
67-
response_future = [
68-
self.rollout_controller.rollout.remote(
65+
response_future = []
66+
for sample in group_data_items:
67+
sample.data.extra_info["root_id"] = sample.uid.root_id
68+
sample.data.extra_info["action_id"] = sample.uid.action_id
69+
fut = self.rollout_controller.rollout.remote(
6970
prompt=sample.data.messages,
7071
input_ids=sample.data.input_ids,
7172
sample_params=sample_params,
7273
extra_params=extra_params,
7374
extra_info=sample.data.extra_info,
7475
)
75-
for sample in group_data_items
76-
]
76+
response_future.append(fut)
7777
try:
7878
rollout_responses = await asyncio.wait_for(
7979
asyncio.gather(*response_future), timeout=self.rollout_timeout
@@ -109,8 +109,7 @@ async def run(
109109
"""
110110
group_data_items = await self.generate(group_data_items, sample_params, extra_params) # type: ignore[assignment]
111111
skip_judger = any(
112-
item.env.rollout.finish_reason == "abort" or item.env.rollout.finish_reason == "failed"
113-
for item in group_data_items
112+
item.env.rollout.finish_reason in ["failed", "skipped", "abort"] for item in group_data_items
114113
)
115114
if self.judger_controller and not skip_judger:
116115
try:

xtuner/v1/ray/rollout/controller.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -106,15 +106,17 @@ def __init__(
106106
self.num_workers = 0
107107
self.worker_server_urls: List[str] = []
108108
self.active_rollout_workers: List[RolloutWorker] = []
109-
self.active_rollout_workers_status: Dict = {}
109+
self.active_workers_to_status: Dict[RolloutWorker, bool] = {}
110+
self.active_url_to_workers: Dict[str, RolloutWorker] = {}
111+
self.url_failed_counts: Dict[str, int] = {}
110112
self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True)
111113
self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
112114
self._get_worker_cls(), infer_config, placement_group
113115
)
114116
self.engine_mesh_list, self.server_url_dict = self.init_workers()
115117
self.start_api_server()
116118
# todo(@duanyanhui): add router to replace native round robin
117-
self.router = SessionRouter(self.active_rollout_workers_status)
119+
self.router = SessionRouter(self.active_workers_to_status)
118120
self.sample_params = SampleParams().dict()
119121
# note: 目前默认使用return_token_ids和return_logprob,并且不使用流式
120122
self.extra_params = dict(
@@ -237,7 +239,10 @@ def init_workers(self):
237239
)
238240
self._update_active_workers_and_urls()
239241
self.worker_server_urls = list(self.worker_server_urls_map.values())
240-
self.active_rollout_workers_status = {worker: True for worker in self.active_rollout_workers}
242+
self.logger.info(f"Rollout worker server URLs: {self.worker_server_urls}")
243+
self.active_workers_to_status = {worker: True for worker in self.active_rollout_workers}
244+
self.active_url_to_workers = dict(zip(self.worker_server_urls, self.active_rollout_workers))
245+
self.url_failed_counts = {url: 0 for url in self.worker_server_urls}
241246
return engine_mesh_list, self.worker_server_urls_map
242247

243248
def check_active_workers(self):
@@ -254,9 +259,19 @@ def check_active_workers(self):
254259
for idx, status in enumerate(active_worker_response):
255260
if not status:
256261
self.logger.info(
257-
f"Rollout worker {self.active_rollout_workers[idx]} is unhealthy. Removing it from active workers."
262+
f"Rollout worker {self.worker_server_urls[idx]} is unhealthy. Removing it from active workers."
258263
)
259-
self.active_rollout_workers_status[self.active_rollout_workers[idx]] = False
264+
self.active_workers_to_status[self.active_rollout_workers[idx]] = False
265+
266+
def deactivate_worker_by_url(self, url):
    """Record one failure for the rollout worker at ``url`` and deactivate it when the retry budget is spent.

    Each call increments the failure counter kept in ``self.url_failed_counts``
    for ``url``. While the counter is still below
    ``self.config.max_retry_per_worker`` only a warning is logged. Once the
    counter reaches the limit, the worker mapped to ``url`` in
    ``self.active_url_to_workers`` is flagged ``False`` in
    ``self.active_workers_to_status``.

    Args:
        url: Server URL of the rollout worker that failed. It must be one of
            the keys registered in ``self.active_url_to_workers``; unknown
            URLs are logged and ignored.
    """
    worker = self.active_url_to_workers.get(url)
    if worker is None:
        # An unregistered URL previously raised KeyError on the counter update
        # (or, if only the counter dict knew the URL, inserted a ``None`` key
        # into the status dict). Neither is useful to the caller, so report it
        # and leave all bookkeeping untouched.
        self.logger.warning(f"Rollout worker {url} is not a registered active worker, skip deactivation.")
        return

    failed_count = self.url_failed_counts.get(url, 0) + 1
    self.url_failed_counts[url] = failed_count

    max_retry = self.config.max_retry_per_worker
    # NOTE(review): ``max_retry_per_worker`` is declared ``Optional[int]``; ``None`` is
    # treated here as "no limit configured", so failures are counted but the worker is
    # never deactivated by this path (previously ``int < None`` raised TypeError).
    # Confirm this is the intended meaning of ``None``.
    if max_retry is None or failed_count < max_retry:
        self.logger.warning(f"Rollout worker {url} failed {failed_count} times, but not deactivated yet.")
        return

    self.active_workers_to_status[worker] = False
260275

261276
async def rollout(
262277
self,
@@ -296,7 +311,6 @@ async def rollout(
296311
self.sample_params.update(sample_params.dict() if sample_params else {})
297312
self.extra_params.update(extra_params if extra_params else {})
298313
if self.print_params_flag:
299-
# 通过print_params_flag控制只打印一次参数
300314
self.logger.info(f"Rollout with sample params: {self.sample_params}, extra params: {self.extra_params}")
301315
self.print_params_flag = False
302316
assert prompt is not None or input_ids is not None, "Either prompt or input_ids must be provided."
@@ -311,8 +325,18 @@ async def rollout(
311325
extra_info=extra_info,
312326
)
313327
try:
314-
response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout)
315-
return response
328+
response, http_result = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout)
329+
if http_result.is_success:
330+
return response
331+
elif http_result.is_retryable or http_result.is_server_error:
332+
response.finish_reason = "failed"
333+
return response
334+
elif http_result.is_client_error:
335+
response.finish_reason = "skipped"
336+
return response
337+
else: # unknown error
338+
raise RuntimeError("Unknown error occurred during rollout. Error message: ", http_result.error_message)
339+
316340
except asyncio.TimeoutError:
317341
self.logger.error("Get response from rollout worker timeout and return the failed response.")
318342
failed_response = RLRolloutResponseItem(
@@ -409,7 +433,7 @@ def _broadcast_to_active_workers(self, method_name: str, block: bool):
409433
A list of futures if `block` is False, otherwise a list of results.
410434
"""
411435
futures = []
412-
for worker, status in self.active_rollout_workers_status.items():
436+
for worker, status in self.active_workers_to_status.items():
413437
if status:
414438
futures.append(getattr(worker, method_name).remote())
415439

xtuner/v1/ray/rollout/worker.py

Lines changed: 42 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker
2222
from xtuner.v1.ray.config import RolloutConfig
2323
from xtuner.v1.utils import get_logger
24+
from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
2425

2526

2627
class RolloutWorker(SingleAcceleratorWorker):
@@ -285,9 +286,8 @@ def _check_infer_engine_version(self, return_token_ids: bool):
285286
)
286287
self.check_flag = False
287288

288-
async def _safe_post_request(self, url, headers, payload) -> Tuple[Optional[httpx.Response], bool, Optional[str]]:
289+
async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
289290
try:
290-
# new_url = self.server_url[-2] + str(int(self.server_url[-1]) + 1) + "'"
291291
req = self.client.build_request(
292292
"POST",
293293
url,
@@ -296,34 +296,11 @@ async def _safe_post_request(self, url, headers, payload) -> Tuple[Optional[http
296296
)
297297
r = await self.client.send(req)
298298
r.raise_for_status()
299-
return r, True, None
300-
# NOTE(@duanyanhui): 目前只有TimeoutException时,第二个返回值为True ,即continue_rollout=True,不影响主程序正常运行
301-
# 其他错误都认为是请求失败,会通过assert进行报错,并且根据错误类型返回不同的error msg.
302-
except httpx.TimeoutException as e:
303-
error_msg = f"create_request error: Request to {url} timed out: {e}"
304-
self.logger.warning(error_msg)
305-
return None, True, None
306-
except httpx.HTTPStatusError as e:
307-
if e.response.status_code == 400:
308-
log_payload = copy.deepcopy(payload)
309-
if "input_ids" in log_payload and log_payload["input_ids"] is not None:
310-
log_payload["input_ids"] = str(log_payload["input_ids"])
311-
error_msg = (
312-
f"Bad Request (400) Error for {url} with payload {log_payload}. Server response: {e.response.text}"
313-
)
314-
return None, False, error_msg
315-
else:
316-
error_msg = f"HTTP error occurred for {url}: {e.response.status_code} - {e.response.text}"
317-
return None, False, error_msg
318-
except httpx.RequestError as e:
319-
log_payload = copy.deepcopy(payload)
320-
if "input_ids" in log_payload and log_payload["input_ids"] is not None:
321-
log_payload["input_ids"] = str(log_payload["input_ids"])
322-
error_msg = f"Request Error occurred while requesting {payload} to {url}: {e}"
323-
return None, False, error_msg
299+
return HttpRequestResult(response=r)
324300
except Exception as e:
325-
error_msg = f"Unexpected Error occurred: {e} with traceback: \n {traceback.format_exc()}"
326-
return None, False, error_msg
301+
error_type = HttpRequestErrorType.from_exception(e)
302+
result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
303+
return result
327304

328305
async def rollout_task(
329306
self,
@@ -335,58 +312,59 @@ async def rollout_task(
335312
extra_params: dict,
336313
format: str,
337314
extra_info: dict,
338-
) -> RLRolloutResponseItem:
339-
uid = str(uuid.uuid4())
315+
) -> Tuple[RLRolloutResponseItem, HttpRequestResult]:
316+
uid = extra_info.get("action_id", str(uuid.uuid4()))
340317
response = None
341-
failed_rollout_response = RLRolloutResponseItem(
342-
finish_reason="failed",
343-
)
318+
failed_rollout_response = RLRolloutResponseItem(finish_reason="failed")
344319
self._check_infer_engine_version("return_token_ids" in extra_params and extra_params["return_token_ids"])
345320

346321
if format == "openai":
347322
openai_prompts, openai_tools = prompts, tools
348323
else:
349324
openai_prompts, openai_tools = self._adapt_input_to_openai_spec(prompts, tools, tool_choice)
325+
350326
if "return_token_ids" in extra_params and extra_params["return_token_ids"]:
351-
response, continue_rollout, error_msg = await self._create_request(
352-
f"{self.server_url}/{self.endpoints['generate']}",
353-
openai_prompts,
354-
input_ids,
355-
openai_tools,
356-
tool_choice,
357-
sample_params=sample_params,
358-
extra_params=extra_params,
359-
extra_info=extra_info,
360-
)
327+
endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
361328
else:
362-
assert prompts is not None, "prompts should not be None when you call v1/chat/completions API"
363-
response, continue_rollout, error_msg = await self._create_request(
364-
f"{self.server_url}/{self.endpoints['v1/chat/completions']}",
365-
openai_prompts,
366-
None,
367-
openai_tools,
368-
tool_choice,
369-
sample_params=sample_params,
370-
extra_params=extra_params,
371-
extra_info=extra_info,
372-
)
373-
assert continue_rollout, (
374-
f"Unhandled error occurred during rollout request creation, You should check infer engine or input params. \n Error message: {error_msg}"
329+
endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
330+
331+
http_result = await self._create_request(
332+
endpoint_url,
333+
openai_prompts,
334+
None,
335+
openai_tools,
336+
tool_choice,
337+
sample_params=sample_params,
338+
extra_params=extra_params,
339+
extra_info=extra_info,
375340
)
376-
if response:
341+
342+
if http_result.response is not None:
377343
try:
378344
rollout_response = (
379345
await self._handle_stream_response(uid, sample_params, extra_params, response)
380346
if extra_params["stream"]
381347
else await self._handle_non_stream_response(uid, sample_params, extra_params, response)
382348
)
383349
finally:
384-
if hasattr(response, "aclose"):
385-
await response.aclose()
386-
return rollout_response
350+
if hasattr(http_result.response, "aclose"):
351+
await http_result.response.aclose()
352+
return rollout_response, http_result
387353
else:
388-
self.logger.warning(f"Retrying rollout for {uid} due to httpx timeout")
389-
return failed_rollout_response
354+
if http_result.is_retryable:
355+
self.logger.warning(f"Retryable error occurred during rollout request {uid} to {http_result.url}")
356+
return failed_rollout_response, http_result
357+
elif http_result.is_server_error:
358+
self.logger.error(
359+
f"Server error during rollout request {uid} to {http_result.url}, please check the server logs."
360+
)
361+
http_result.url = self.server_url
362+
return failed_rollout_response, http_result
363+
else: # http_result.is_client_error:
364+
self.logger.error(
365+
f"Client error during rollout request {uid} to {http_result.url} and skip this request."
366+
)
367+
return failed_rollout_response, http_result
390368

391369
async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem:
392370
last_trajectory = ""
@@ -556,7 +534,7 @@ async def rollout(
556534
extra_params: dict = dict(),
557535
format: str = "openai",
558536
extra_info: dict = dict(),
559-
) -> RLRolloutResponseItem:
537+
) -> Tuple[RLRolloutResponseItem, HttpRequestResult]:
560538
"""Public method to initiate a rollout.
561539
562540
Args:

0 commit comments

Comments
 (0)