 from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker
 from xtuner.v1.ray.config import RolloutConfig
 from xtuner.v1.utils import get_logger
+from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult


 class RolloutWorker(SingleAcceleratorWorker):
@@ -285,9 +286,8 @@ def _check_infer_engine_version(self, return_token_ids: bool):
             )
         self.check_flag = False

-    async def _safe_post_request(self, url, headers, payload) -> Tuple[Optional[httpx.Response], bool, Optional[str]]:
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
         try:
-            # new_url = self.server_url[-2] + str(int(self.server_url[-1]) + 1) + "'"
             req = self.client.build_request(
                 "POST",
                 url,
@@ -296,34 +296,11 @@ async def _safe_post_request(self, url, headers, payload) -> Tuple[Optional[http
             )
             r = await self.client.send(req)
             r.raise_for_status()
-            return r, True, None
-        # NOTE(@duanyanhui): Currently only httpx.TimeoutException sets the second return value to True (continue_rollout=True), so it does not interrupt the main program.
-        # All other errors are treated as failed requests: they are raised via assert, and a different error msg is returned depending on the error type.
-        except httpx.TimeoutException as e:
-            error_msg = f"create_request error: Request to {url} timed out: {e}"
-            self.logger.warning(error_msg)
-            return None, True, None
-        except httpx.HTTPStatusError as e:
-            if e.response.status_code == 400:
-                log_payload = copy.deepcopy(payload)
-                if "input_ids" in log_payload and log_payload["input_ids"] is not None:
-                    log_payload["input_ids"] = str(log_payload["input_ids"])
-                error_msg = (
-                    f"Bad Request (400) Error for {url} with payload {log_payload}. Server response: {e.response.text}"
-                )
-                return None, False, error_msg
-            else:
-                error_msg = f"HTTP error occurred for {url}: {e.response.status_code} - {e.response.text}"
-                return None, False, error_msg
-        except httpx.RequestError as e:
-            log_payload = copy.deepcopy(payload)
-            if "input_ids" in log_payload and log_payload["input_ids"] is not None:
-                log_payload["input_ids"] = str(log_payload["input_ids"])
-            error_msg = f"Request Error occurred while requesting {payload} to {url}: {e}"
-            return None, False, error_msg
+            return HttpRequestResult(response=r)
         except Exception as e:
-            error_msg = f"Unexpected Error occurred: {e} with traceback: \n{traceback.format_exc()}"
-            return None, False, error_msg
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            return result

     async def rollout_task(
         self,
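The `xtuner.v1.utils.httpx_utils` module imported above is not included in this diff. Based only on how `HttpRequestResult` and `HttpRequestErrorType` are used here (`HttpRequestResult(response=r)`, `HttpRequestErrorType.from_exception(e)`, and the `is_retryable` / `is_server_error` / `is_client_error` checks in `rollout_task` below), a minimal sketch of what such a helper might look like is given next; the field names, error taxonomy, and retry semantics are assumptions, not the module's actual implementation.

```python
# Hypothetical sketch of xtuner.v1.utils.httpx_utils, inferred from how the
# result object is used in this diff; the real module may differ.
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

import httpx


class HttpRequestErrorType(Enum):
    TIMEOUT = auto()       # httpx.TimeoutException -> worth retrying
    CLIENT_ERROR = auto()  # 4xx responses -> bad payload, retrying will not help
    SERVER_ERROR = auto()  # 5xx responses -> inspect the server logs
    CONNECTION = auto()    # transport-level failures
    UNEXPECTED = auto()    # anything else

    @classmethod
    def from_exception(cls, exc: Exception) -> "HttpRequestErrorType":
        # Map an httpx exception to a coarse error category.
        if isinstance(exc, httpx.TimeoutException):
            return cls.TIMEOUT
        if isinstance(exc, httpx.HTTPStatusError):
            return cls.CLIENT_ERROR if exc.response.status_code < 500 else cls.SERVER_ERROR
        if isinstance(exc, httpx.RequestError):
            return cls.CONNECTION
        return cls.UNEXPECTED


@dataclass
class HttpRequestResult:
    response: Optional[httpx.Response] = None
    error_type: Optional[HttpRequestErrorType] = None
    exception: Optional[Exception] = None
    url: Optional[str] = None
    payload: Optional[dict] = None

    @property
    def is_retryable(self) -> bool:
        # Mirrors the removed NOTE: only timeouts let the rollout continue.
        return self.error_type is HttpRequestErrorType.TIMEOUT

    @property
    def is_server_error(self) -> bool:
        return self.error_type is HttpRequestErrorType.SERVER_ERROR

    @property
    def is_client_error(self) -> bool:
        # Everything that is neither retryable nor a server fault, e.g. a 400 Bad Request.
        return self.error_type is not None and not self.is_retryable and not self.is_server_error
```

Folding the old `(response, continue_rollout, error_msg)` tuple into a single result object keeps `_safe_post_request` to one success path and one failure path, with the error classification moved into the shared helper.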
@@ -335,58 +312,59 @@ async def rollout_task(
         extra_params: dict,
         format: str,
         extra_info: dict,
-    ) -> RLRolloutResponseItem:
-        uid = str(uuid.uuid4())
+    ) -> Tuple[RLRolloutResponseItem, HttpRequestResult]:
+        uid = extra_info.get("action_id", str(uuid.uuid4()))
         response = None
-        failed_rollout_response = RLRolloutResponseItem(
-            finish_reason="failed",
-        )
+        failed_rollout_response = RLRolloutResponseItem(finish_reason="failed")
         self._check_infer_engine_version("return_token_ids" in extra_params and extra_params["return_token_ids"])

         if format == "openai":
             openai_prompts, openai_tools = prompts, tools
         else:
             openai_prompts, openai_tools = self._adapt_input_to_openai_spec(prompts, tools, tool_choice)
+
         if "return_token_ids" in extra_params and extra_params["return_token_ids"]:
-            response, continue_rollout, error_msg = await self._create_request(
-                f"{self.server_url}/{self.endpoints['generate']}",
-                openai_prompts,
-                input_ids,
-                openai_tools,
-                tool_choice,
-                sample_params=sample_params,
-                extra_params=extra_params,
-                extra_info=extra_info,
-            )
+            endpoint_url = f"{self.server_url}/{self.endpoints['generate']}"
         else:
-            assert prompts is not None, "prompts should not be None when you call v1/chat/completions API"
-            response, continue_rollout, error_msg = await self._create_request(
-                f"{self.server_url}/{self.endpoints['v1/chat/completions']}",
-                openai_prompts,
-                None,
-                openai_tools,
-                tool_choice,
-                sample_params=sample_params,
-                extra_params=extra_params,
-                extra_info=extra_info,
-            )
-        assert continue_rollout, (
-            f"Unhandled error occurred during rollout request creation, You should check infer engine or input params. \nError message: {error_msg}"
+            endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
+
+        http_result = await self._create_request(
+            endpoint_url,
+            openai_prompts,
+            None,
+            openai_tools,
+            tool_choice,
+            sample_params=sample_params,
+            extra_params=extra_params,
+            extra_info=extra_info,
         )
-        if response:
+
+        if http_result.response is not None:
             try:
                 rollout_response = (
                     await self._handle_stream_response(uid, sample_params, extra_params, response)
                     if extra_params["stream"]
                     else await self._handle_non_stream_response(uid, sample_params, extra_params, response)
                 )
             finally:
-                if hasattr(response, "aclose"):
-                    await response.aclose()
-            return rollout_response
+                if hasattr(http_result.response, "aclose"):
+                    await http_result.response.aclose()
+            return rollout_response, http_result
         else:
-            self.logger.warning(f"Retrying rollout for {uid} due to httpx timeout")
-            return failed_rollout_response
+            if http_result.is_retryable:
+                self.logger.warning(f"Retryable error occurred during rollout request {uid} to {http_result.url}")
+                return failed_rollout_response, http_result
+            elif http_result.is_server_error:
+                self.logger.error(
+                    f"Server error during rollout request {uid} to {http_result.url}, please check the server logs."
+                )
+                http_result.url = self.server_url
+                return failed_rollout_response, http_result
+            else:  # http_result.is_client_error:
+                self.logger.error(
+                    f"Client error during rollout request {uid} to {http_result.url} and skip this request."
+                )
+                return failed_rollout_response, http_result

     async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem:
         last_trajectory = ""
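With this change, `rollout_task` (and the public `rollout` below) returns a `(RLRolloutResponseItem, HttpRequestResult)` tuple rather than a bare response item, so callers decide per error class whether to retry. The sketch below shows one possible caller-side pattern; `rollout_with_retry`, the `worker.rollout(prompts=...)` call shape, and the retry policy are illustrative assumptions, not part of this PR.

```python
# Illustrative caller-side handling of the new (response, http_result) tuple.
# Not part of this PR: the helper name, call signature, and retry policy are assumptions.
async def rollout_with_retry(worker, prompts, max_retries: int = 3):
    response_item = None
    for _ in range(max_retries):
        response_item, http_result = await worker.rollout(prompts=prompts)
        if response_item.finish_reason != "failed":
            return response_item  # successful rollout
        if http_result.is_retryable:
            continue  # e.g. timeout: issue the same request again
        if http_result.is_server_error:
            # rollout_task resets http_result.url to the server base URL for server errors,
            # so it can be surfaced for server-side debugging.
            raise RuntimeError(f"Rollout server error at {http_result.url}: {http_result.exception}")
        # Client error: the payload itself is bad, retrying will not help.
        raise ValueError(f"Rollout request rejected: {http_result.exception}")
    return response_item  # retries exhausted, return the last failed item
```

In particular, the old in-method `assert continue_rollout` is gone, so surfacing non-retryable failures is now the caller's responsibility.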
@@ -556,7 +534,7 @@ async def rollout(
         extra_params: dict = dict(),
         format: str = "openai",
         extra_info: dict = dict(),
-    ) -> RLRolloutResponseItem:
+    ) -> Tuple[RLRolloutResponseItem, HttpRequestResult]:
         """Public method to initiate a rollout.

         Args: