From 2cb5c1ea905021f2d9b40039ffb5977aba8dceb1 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 13 Nov 2025 17:04:23 +0800 Subject: [PATCH 01/18] [Refactor] refactor http request error --- xtuner/v1/data_proto/rl_data.py | 2 +- xtuner/v1/ray/config/worker.py | 10 ++ xtuner/v1/ray/dataflow/flow.py | 5 + xtuner/v1/ray/environment/single_turn_env.py | 15 ++- xtuner/v1/ray/rollout/controller.py | 42 +++++-- xtuner/v1/ray/rollout/worker.py | 103 ++++++--------- xtuner/v1/utils/httpx_utils.py | 125 +++++++++++++++++++ 7 files changed, 222 insertions(+), 80 deletions(-) create mode 100644 xtuner/v1/utils/httpx_utils.py diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py index ca3318ea0..c723d0f82 100644 --- a/xtuner/v1/data_proto/rl_data.py +++ b/xtuner/v1/data_proto/rl_data.py @@ -69,7 +69,7 @@ class RLRolloutResponseItem(BaseModel): response: Optional[str] = None response_ids: Optional[List[int]] = None num_return_tokens: Optional[int] = None - finish_reason: Optional[str] = None + finish_reason: Optional[str] = None # "stop", "length", "abort", "failed", "skipped" logprobs: Optional[List[float]] = None extra_info: Dict[str, Any] = dict() diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py index ba91d5ca5..b03301b78 100644 --- a/xtuner/v1/ray/config/worker.py +++ b/xtuner/v1/ray/config/worker.py @@ -198,6 +198,13 @@ class RolloutConfig(BaseModel): help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.', ), ] = {} + max_retry_per_worker: Annotated[ + Optional[int], + Parameter( + group=infer_group, + help="Maximum number of retries per rollout worker before deactivation.", + ), + ] = None worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" def model_post_init(self, __context: Any) -> None: @@ -259,6 +266,9 @@ def model_post_init(self, __context: Any) -> None: else: self.rollout_max_batch_size_per_instance = 128 + if self.max_retry_per_worker is None: + self.max_retry_per_worker = self.rollout_max_batch_size_per_instance + self.worker_log_dir.mkdir(parents=True, exist_ok=True) diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py index c0a46778b..545ed4cc4 100644 --- a/xtuner/v1/ray/dataflow/flow.py +++ b/xtuner/v1/ray/dataflow/flow.py @@ -192,6 +192,11 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte f"Dataflow item check failed for {group_data_items[0].uid.action_id} response. Returning meta for retry." ) return group_data_items + if any(item.env.rollout.finish_reason == "skipped" for item in group_data_items): + self.logger.warning( + f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request." + ) + return # step 3: filter filtered_group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined] diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py index 3024fbfd4..79d98ee37 100644 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ b/xtuner/v1/ray/environment/single_turn_env.py @@ -62,18 +62,18 @@ async def generate( and state from the rollout controller. 
""" if self.rollout_controller: - # 在env中对输入的数据进行转换,是为了支持rollout_controller单独作为rollout engine使用,使各个模块进行解耦 - # 每个模块返回独立的data item, 在env中进行更新 - response_future = [ - self.rollout_controller.rollout.remote( + response_future = [] + for sample in group_data_items: + sample.data.extra_info["root_id"] = sample.uid.root_id + sample.data.extra_info["action_id"] = sample.uid.action_id + fut = self.rollout_controller.rollout.remote( prompt=sample.data.messages, input_ids=sample.data.input_ids, sample_params=sample_params, extra_params=extra_params, extra_info=sample.data.extra_info, ) - for sample in group_data_items - ] + response_future.append(fut) try: rollout_responses = await asyncio.wait_for( asyncio.gather(*response_future), timeout=self.rollout_timeout @@ -109,8 +109,7 @@ async def run( """ group_data_items = await self.generate(group_data_items, sample_params, extra_params) # type: ignore[assignment] skip_judger = any( - item.env.rollout.finish_reason == "abort" or item.env.rollout.finish_reason == "failed" - for item in group_data_items + item.env.rollout.finish_reason in ["failed", "skipped", "abort"] for item in group_data_items ) if self.judger_controller and not skip_judger: try: diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 70e58b8f0..4758fba19 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -106,7 +106,9 @@ def __init__( self.num_workers = 0 self.worker_server_urls: List[str] = [] self.active_rollout_workers: List[RolloutWorker] = [] - self.active_rollout_workers_status: Dict = {} + self.active_workers_to_status: Dict[RolloutWorker, bool] = {} + self.active_url_to_workers: Dict[str, RolloutWorker] = {} + self.url_failed_counts: Dict[str, int] = {} self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True) self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( self._get_worker_cls(), infer_config, placement_group @@ -114,7 +116,7 @@ def __init__( self.engine_mesh_list, self.server_url_dict = self.init_workers() self.start_api_server() # todo(@duanyanhui): add router to replace native round robin - self.router = SessionRouter(self.active_rollout_workers_status) + self.router = SessionRouter(self.active_workers_to_status) self.sample_params = SampleParams().dict() # note: 目前默认使用return_token_ids和return_logprob,并且不使用流式 self.extra_params = dict( @@ -237,7 +239,10 @@ def init_workers(self): ) self._update_active_workers_and_urls() self.worker_server_urls = list(self.worker_server_urls_map.values()) - self.active_rollout_workers_status = {worker: True for worker in self.active_rollout_workers} + self.logger.info(f"Rollout worker server URLs: {self.worker_server_urls}") + self.active_workers_to_status = {worker: True for worker in self.active_rollout_workers} + self.active_url_to_workers = dict(zip(self.worker_server_urls, self.active_rollout_workers)) + self.url_failed_counts = {url: 0 for url in self.worker_server_urls} return engine_mesh_list, self.worker_server_urls_map def check_active_workers(self): @@ -254,9 +259,19 @@ def check_active_workers(self): for idx, status in enumerate(active_worker_response): if not status: self.logger.info( - f"Rollout worker {self.active_rollout_workers[idx]} is unhealthy. Removing it from active workers." + f"Rollout worker {self.worker_server_urls[idx]} is unhealthy. Removing it from active workers." 
diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py
index 70e58b8f0..4758fba19 100644
--- a/xtuner/v1/ray/rollout/controller.py
+++ b/xtuner/v1/ray/rollout/controller.py
@@ -106,7 +106,9 @@
         self.num_workers = 0
         self.worker_server_urls: List[str] = []
         self.active_rollout_workers: List[RolloutWorker] = []
-        self.active_rollout_workers_status: Dict = {}
+        self.active_workers_to_status: Dict[RolloutWorker, bool] = {}
+        self.active_url_to_workers: Dict[str, RolloutWorker] = {}
+        self.url_failed_counts: Dict[str, int] = {}
         self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True)
         self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
             self._get_worker_cls(), infer_config, placement_group
@@ -114,7 +116,7 @@
         self.engine_mesh_list, self.server_url_dict = self.init_workers()
         self.start_api_server()
         # todo(@duanyanhui): add router to replace native round robin
-        self.router = SessionRouter(self.active_rollout_workers_status)
+        self.router = SessionRouter(self.active_workers_to_status)
         self.sample_params = SampleParams().dict()
         # note: return_token_ids and return_logprob are enabled by default, and streaming is not used
         self.extra_params = dict(
@@ -237,7 +239,10 @@
         )
         self._update_active_workers_and_urls()
         self.worker_server_urls = list(self.worker_server_urls_map.values())
-        self.active_rollout_workers_status = {worker: True for worker in self.active_rollout_workers}
+        self.logger.info(f"Rollout worker server URLs: {self.worker_server_urls}")
+        self.active_workers_to_status = {worker: True for worker in self.active_rollout_workers}
+        self.active_url_to_workers = dict(zip(self.worker_server_urls, self.active_rollout_workers))
+        self.url_failed_counts = {url: 0 for url in self.worker_server_urls}
         return engine_mesh_list, self.worker_server_urls_map

     def check_active_workers(self):
@@ -254,9 +259,19 @@
         for idx, status in enumerate(active_worker_response):
             if not status:
                 self.logger.info(
-                    f"Rollout worker {self.active_rollout_workers[idx]} is unhealthy. Removing it from active workers."
+                    f"Rollout worker {self.worker_server_urls[idx]} is unhealthy. Removing it from active workers."
                 )
-                self.active_rollout_workers_status[self.active_rollout_workers[idx]] = False
+                self.active_workers_to_status[self.active_rollout_workers[idx]] = False
+
+    def deactivate_worker_by_url(self, url):
+        self.url_failed_counts[url] += 1
+        if self.url_failed_counts[url] < self.config.max_retry_per_worker:
+            self.logger.warning(
+                f"Rollout worker {url} failed {self.url_failed_counts[url]} times, but not deactivated yet."
+            )
+            return
+        inactive_workers = self.active_url_to_workers.get(url)
+        self.active_workers_to_status[inactive_workers] = False

     async def rollout(
         self,
@@ -296,7 +311,6 @@
         self.sample_params.update(sample_params.dict() if sample_params else {})
         self.extra_params.update(extra_params if extra_params else {})
         if self.print_params_flag:
-            # print_params_flag ensures the parameters are logged only once
             self.logger.info(f"Rollout with sample params: {self.sample_params}, extra params: {self.extra_params}")
             self.print_params_flag = False
         assert prompt is not None or input_ids is not None, "Either prompt or input_ids must be provided."
@@ -311,8 +325,18 @@
             extra_info=extra_info,
         )
         try:
-            response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout)
-            return response
+            response, http_result = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout)
+            if http_result.is_success:
+                return response
+            elif http_result.is_retryable or http_result.is_server_error:
+                response.finish_reason = "failed"
+                return response
+            elif http_result.is_client_error:
+                response.finish_reason = "skipped"
+                return response
+            else:  # unknown error
+                raise RuntimeError(f"Unknown error occurred during rollout. Error message: {http_result.error_msg}")
+
         except asyncio.TimeoutError:
             self.logger.error("Get response from rollout worker timeout and return the failed response.")
             failed_response = RLRolloutResponseItem(
@@ -409,7 +433,7 @@
         A list of futures if `block` is False, otherwise a list of results.
         """
         futures = []
-        for worker, status in self.active_rollout_workers_status.items():
+        for worker, status in self.active_workers_to_status.items():
             if status:
                 futures.append(getattr(worker, method_name).remote())

diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py
index 61b1dcd26..6921b2271 100644
--- a/xtuner/v1/ray/rollout/worker.py
+++ b/xtuner/v1/ray/rollout/worker.py
@@ -21,6 +21,7 @@
 from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker
 from xtuner.v1.ray.config import RolloutConfig
 from xtuner.v1.utils import get_logger
+from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult


 class RolloutWorker(SingleAcceleratorWorker):
@@ -288,9 +289,8 @@
             )
             self.check_flag = False

-    async def _safe_post_request(self, url, headers, payload) -> Tuple[Optional[httpx.Response], bool, Optional[str]]:
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
         try:
-            # new_url = self.server_url[-2] + str(int(self.server_url[-1]) + 1) + "'"
             req = self.client.build_request(
                 "POST",
                 url,
@@ -299,34 +299,11 @@
...
             )
             r = await self.client.send(req)
             r.raise_for_status()
-            return r, True, None
+            return HttpRequestResult(response=r)
-            # NOTE(@duanyanhui): only a TimeoutException returns True as the second value (continue_rollout=True), so it does not interrupt the main program
-            # All other errors are treated as failed requests: they are raised via assert, with a different error msg per error type.
- except httpx.TimeoutException as e: - error_msg = f"create_request error: Request to {url} timed out: {e}" - self.logger.warning(error_msg) - return None, True, None - except httpx.HTTPStatusError as e: - if e.response.status_code == 400: - log_payload = copy.deepcopy(payload) - if "input_ids" in log_payload and log_payload["input_ids"] is not None: - log_payload["input_ids"] = str(log_payload["input_ids"]) - error_msg = ( - f"Bad Request (400) Error for {url} with payload {log_payload}. Server response: {e.response.text}" - ) - return None, False, error_msg - else: - error_msg = f"HTTP error occurred for {url}: {e.response.status_code} - {e.response.text}" - return None, False, error_msg - except httpx.RequestError as e: - log_payload = copy.deepcopy(payload) - if "input_ids" in log_payload and log_payload["input_ids"] is not None: - log_payload["input_ids"] = str(log_payload["input_ids"]) - error_msg = f"Request Error occurred while requesting {payload} to {url}: {e}" - return None, False, error_msg + return HttpRequestResult(response=r) except Exception as e: - error_msg = f"Unexpected Error occurred: {e} with traceback: \n {traceback.format_exc()}" - return None, False, error_msg + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + return result async def rollout_task( self, @@ -338,45 +315,35 @@ async def rollout_task( extra_params: dict, format: str, extra_info: dict, - ) -> RLRolloutResponseItem: - uid = str(uuid.uuid4()) + ) -> Tuple[RLRolloutResponseItem, HttpRequestResult]: + uid = extra_info.get("action_id", str(uuid.uuid4())) response = None - failed_rollout_response = RLRolloutResponseItem( - finish_reason="failed", - ) + failed_rollout_response = RLRolloutResponseItem(finish_reason="failed") self._check_infer_engine_version("return_token_ids" in extra_params and extra_params["return_token_ids"]) if format == "openai": openai_prompts, openai_tools = prompts, tools else: openai_prompts, openai_tools = self._adapt_input_to_openai_spec(prompts, tools, tool_choice) + if "return_token_ids" in extra_params and extra_params["return_token_ids"]: - response, continue_rollout, error_msg = await self._create_request( - f"{self.server_url}/{self.endpoints['generate']}", - openai_prompts, - input_ids, - openai_tools, - tool_choice, - sample_params=sample_params, - extra_params=extra_params, - extra_info=extra_info, - ) + endpoint_url = f"{self.server_url}/{self.endpoints['generate']}" else: - assert prompts is not None, "prompts should not be None when you call v1/chat/completions API" - response, continue_rollout, error_msg = await self._create_request( - f"{self.server_url}/{self.endpoints['v1/chat/completions']}", - openai_prompts, - None, - openai_tools, - tool_choice, - sample_params=sample_params, - extra_params=extra_params, - extra_info=extra_info, - ) - assert continue_rollout, ( - f"Unhandled error occurred during rollout request creation, You should check infer engine or input params. 
\n Error message: {error_msg}"
-        )
+            endpoint_url = f"{self.server_url}/{self.endpoints['v1/chat/completions']}"
+
+        http_result = await self._create_request(
+            endpoint_url,
+            openai_prompts,
+            input_ids,
+            openai_tools,
+            tool_choice,
+            sample_params=sample_params,
+            extra_params=extra_params,
+            extra_info=extra_info,
+        )
+
-        if response:
+        if http_result.response is not None:
+            response = http_result.response
             try:
                 rollout_response = (
                     await self._handle_stream_response(uid, sample_params, extra_params, response)
...
             finally:
                 if hasattr(response, "aclose"):
                     await response.aclose()
-            return rollout_response
+            return rollout_response, http_result
         else:
-            self.logger.warning(f"Retrying rollout for {uid} due to httpx timeout")
-            return failed_rollout_response
+            if http_result.is_retryable:
+                self.logger.warning(f"Retryable error occurred during rollout request {uid} to {http_result.url}")
+                return failed_rollout_response, http_result
+            elif http_result.is_server_error:
+                self.logger.error(
+                    f"Server error during rollout request {uid} to {http_result.url}, please check the server logs."
+                )
+                http_result.url = self.server_url
+                return failed_rollout_response, http_result
+            else:  # http_result.is_client_error:
+                self.logger.error(
+                    f"Client error during rollout request {uid} to {http_result.url} and skip this request."
+                )
+                return failed_rollout_response, http_result

     async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem:
         last_trajectory = ""
@@ -559,7 +538,7 @@
         extra_params: dict = dict(),
         format: str = "openai",
         extra_info: dict = dict(),
-    ) -> RLRolloutResponseItem:
+    ) -> Tuple[RLRolloutResponseItem, HttpRequestResult]:
         """Public method to initiate a rollout.

         Args:
diff --git a/xtuner/v1/utils/httpx_utils.py b/xtuner/v1/utils/httpx_utils.py
new file mode 100644
index 000000000..68c8fb5f7
--- /dev/null
+++ b/xtuner/v1/utils/httpx_utils.py
@@ -0,0 +1,125 @@
+import traceback
+from dataclasses import dataclass, field
+from enum import IntEnum
+from typing import Any, Dict, Optional
+
+import httpx
+
+
+class HttpRequestErrorType(IntEnum):
+    """An enumeration for HTTP status codes and client-side request errors.
+    Inherits from IntEnum for direct integer comparison.
+
+    Custom codes are used for client-side exceptions that do not have
+    an HTTP status code.
+
+    Example:
+        if error_code == HttpRequestErrorType.BAD_REQUEST:
+            print("Bad request from server!")
+        elif error_code == HttpRequestErrorType.TIMEOUT_ERROR:
+            print("Client-side request timed out!")
+    """
+
+    # --- Custom Codes for Client-Side and Unhandled Errors ---
+    UNKNOWN_ERROR = -1
+    TIMEOUT_ERROR = 0
+    REQUEST_ERROR = 1
+    # --- Standard HTTP Status Codes ---
+    SUCCESS = 200
+    BAD_REQUEST = 400
+    UNAUTHORIZED = 401
+    FORBIDDEN = 403
+    NOT_FOUND = 404
+    REQUEST_TIMEOUT = 408
+    TOO_MANY_REQUESTS = 429
+    INTERNAL_SERVER_ERROR = 500
+    BAD_GATEWAY = 502
+    SERVICE_UNAVAILABLE = 503
+    GATEWAY_TIMEOUT = 504
+
+    @classmethod
+    def from_exception(cls, e: Exception) -> "HttpRequestErrorType":
+        """Factory method to determine the HttpRequestErrorType from a given
+        exception."""
+        if isinstance(e, httpx.TimeoutException):
+            return cls.TIMEOUT_ERROR
+
+        if isinstance(e, httpx.HTTPStatusError):
+            # Try to match the status code to an existing enum member.
+            # If not found, it's an unknown HTTP error, but we can still categorize it.
+            # For simplicity here, we'll just return the known ones or fall back.
+            try:
+                return cls(e.response.status_code)
+            except ValueError:
+                # The status code is not a defined member of our enum.
+                # We can decide to return UNKNOWN_ERROR or handle it differently.
+                return cls.UNKNOWN_ERROR
+
+        if isinstance(e, httpx.RequestError):
+            # This check comes after its subclasses (TimeoutException, HTTPStatusError)
+            return cls.REQUEST_ERROR
+
+        # For any other standard Python exception
+        return cls.UNKNOWN_ERROR
+
+
+@dataclass
+class HttpRequestResult:
+    response: Optional[httpx.Response] = None
+    error_type: HttpRequestErrorType = HttpRequestErrorType.SUCCESS
+    error_msg: Optional[str] = None
+
+    exception: Optional[Exception] = field(default=None, repr=False)
+    url: Optional[str] = field(default=None, repr=False)
+    payload: Optional[Dict[str, Any]] = field(default=None, repr=False)
+
+    def __post_init__(self):
+        """Generate error_msg automatically if an exception is provided and
+        error_msg is not already set."""
+        # Only generate a message if one hasn't been provided already.
+        if self.error_msg is None and self.error_type != HttpRequestErrorType.SUCCESS:
+            if self.payload is not None and "input_ids" in self.payload:
+                # Stringify a copy for readable logging without mutating the caller's payload.
+                self.payload = {**self.payload, "input_ids": str(self.payload["input_ids"])}
+
+            default_messages = {
+                HttpRequestErrorType.UNKNOWN_ERROR: f"An unknown error {self.exception} occurred. Traceback: {traceback.format_exc()}",
+                HttpRequestErrorType.TIMEOUT_ERROR: "The request timed out.",
+                HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred. Exception type: {type(self.exception)}",
+                HttpRequestErrorType.BAD_REQUEST: f"Bad Request (400): The server could not process the request {self.payload}",
+                HttpRequestErrorType.UNAUTHORIZED: "Unauthorized (401): Authentication failed or is required.",
+                HttpRequestErrorType.FORBIDDEN: "Forbidden (403): Access is denied.",
+                HttpRequestErrorType.NOT_FOUND: "Not Found (404): The resource was not found.",
+                HttpRequestErrorType.REQUEST_TIMEOUT: f"Request Timeout (408): The server timed out waiting for the request {self.payload}.",
+                HttpRequestErrorType.TOO_MANY_REQUESTS: "Too Many Requests (429): Rate limit exceeded.",
+                HttpRequestErrorType.INTERNAL_SERVER_ERROR: f"Internal Server Error (500) {self.exception} occurred in {self.url}, Traceback: {traceback.format_exc()}",
+                HttpRequestErrorType.BAD_GATEWAY: f"Bad Gateway (502) in {self.url}.",
+                HttpRequestErrorType.SERVICE_UNAVAILABLE: f"Service Unavailable (503) in {self.url}.",
+                HttpRequestErrorType.GATEWAY_TIMEOUT: f"Gateway Timeout (504) in {self.url}.",
+            }
+
+            # Get the message from the map, or provide a generic fallback.
+            self.error_msg = default_messages.get(
+                self.error_type, f"An error occurred with status code: {self.error_type.value}"
+            )
+            if self.error_type == HttpRequestErrorType.REQUEST_ERROR and self.exception:
+                if hasattr(self.exception, "__cause__") and self.exception.__cause__:
+                    self.error_msg += f" __cause__: {self.exception.__cause__}"
+
+    @property
+    def is_success(self) -> bool:
+        return self.error_type == HttpRequestErrorType.SUCCESS
+
+    @property
+    def is_retryable(self) -> bool:
+        return self.error_type in {
+            HttpRequestErrorType.TIMEOUT_ERROR,
+            HttpRequestErrorType.REQUEST_ERROR,
+        }
+
+    @property
+    def is_client_error(self) -> bool:
+        return 400 <= self.error_type < 500
+
+    @property
+    def is_server_error(self) -> bool:
+        return 500 <= self.error_type < 600
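Taken together, `_safe_post_request` now reduces every outcome to a single `HttpRequestResult`, and callers branch on the `is_*` properties instead of re-parsing exceptions. A minimal sketch of the intended classification, assuming only the module added above (the URL and the 503 status are illustrative):

    import httpx

    from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult

    req = httpx.Request("POST", "http://localhost:8000/generate")  # hypothetical endpoint
    try:
        # the same shape of error that raise_for_status() produces for a 5xx reply
        raise httpx.HTTPStatusError("mocked", request=req, response=httpx.Response(503, request=req))
    except Exception as e:
        result = HttpRequestResult(error_type=HttpRequestErrorType.from_exception(e), exception=e, url=str(req.url))

    assert result.error_type == HttpRequestErrorType.SERVICE_UNAVAILABLE
    assert result.is_server_error and not result.is_retryable  # only timeouts and transport errors are retryable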
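+
+# The controllers above differ from the production RolloutController only in the
+# worker class they spawn. The mock workers (defined in xtuner/v1/utils/rl_test_utils.py)
+# fail every request with one fixed exception class, so the retry, skip and worker
+# deactivation paths are exercised end to end without a live inference server.
+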
os.environ["XTUNER_USE_FA3"] + ray.shutdown() + + def setUp(self): + ray.init(num_cpus=80, ignore_reinit_error=True) + self.global_batch_size = 3 + self.max_prompt_length = 4096 + self.max_response_length = 128 + self.max_concurrent = 3 + self.max_retry_times = 3 + + self.resources_cfg = AcceleratorResourcesConfig( + accelerator=resource_map[torch.accelerator.current_accelerator().type], + num_workers=8, + num_cpus_per_worker=8, + cpu_memory_per_worker=16 * 1024**3, # 16 GB + ) + self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg) + + self.rollout_cfg = RolloutConfig( + env="test_mock_rollout", + model_path=MODEL_PATH, + model_name=os.path.basename(MODEL_PATH).lower(), + tokenizer_path=MODEL_PATH, + tensor_parallel_size=1, + context_length=self.max_prompt_length + self.max_response_length, + ) + tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True) + + self.dataflow_cfg = DataFlowConfig( + max_concurrent=self.max_concurrent, + global_batch_size=self.global_batch_size, + max_retry_times=self.max_retry_times + ) + train_dataset_cfg = [{ + "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH), + "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length), + }] + dataloader_cfg = DataloaderConfig( + collator='fake_collator', + pack_level='none', + group_by_length=False, + ) + self.replay_buffer_cfg = ReplayBufferConfig( + dataset_cfg=train_dataset_cfg, + dataloader_cfg=dataloader_cfg, + tokenizer=tokenizer, + ) + + def _run_mock_test(self, mock_controller_cls, error_name: str): + rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg) + self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller) + self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env) + + # 运行一个批次 + completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3)) + + status = ray.get(self.test_dataflow.get_replaybuffer_status.remote()) + print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}") + self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.") + self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.") + self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected {self.global_batch_size} rollouts to be interrupted.") + + + def test_rollout_with_timeout_mock(self): + self._run_mock_test(MockTimeoutRolloutController, "timeout") + + def test_rollout_with_request_error_mock(self): + self._run_mock_test(MockRequestErrorRolloutController, "request error") + + def test_rollout_with_client_error_mock(self): + self._run_mock_test(MockClientErrorRolloutController, "client error") + + def test_rollout_with_server_error_mock(self): + self._run_mock_test(MockServerErrorRolloutController, "server error") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py index 431abba2a..98f2ff563 100644 --- a/tests/ray/test_rollout.py +++ b/tests/ray/test_rollout.py @@ -114,29 +114,7 @@ def setUp(self): def tearDown(self): ray.shutdown() - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_dataflow_with_failed_request(self): - failed_dataflow_cfg = DataFlowConfig( - env="test", - max_concurrent=1, - prompt_repeat_k=2, - 
global_batch_size=1, - enable_partial_rollout=0, - ) - self.test_env = SingleTurnEnvironment.remote( - "test_env", - self.pg, - rollout_cfg=self.rollout_cfg, - ) - self.test_flow = DataFlow.remote("test_env", - failed_dataflow_cfg, - self.replay_buffer_cfg, - self.test_env - ) - sample_params = SampleParams(temperature=2.5) # invalid temperature to trigger error - with self.assertRaises(AssertionError): - ray.get(self.test_flow.run.remote(num=1, sample_params=sample_params), timeout=300) - + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") def test_lmdeploy_generate(self): sample_params = SampleParams(temperature=0.0) diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py index 545ed4cc4..174f7bedc 100644 --- a/xtuner/v1/ray/dataflow/flow.py +++ b/xtuner/v1/ray/dataflow/flow.py @@ -196,6 +196,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte self.logger.warning( f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request." ) + self.logger.debug(f"Worker task skipped successfully for {group_data_items[0].uid.action_id}.") return # step 3: filter @@ -258,14 +259,14 @@ async def concurrent_task_runner(self): if result[0].extra_info.retry_times < self.config.max_retry_times: # If the retry count is less than max_retry_times, retry the task self.logger.info( - f"Retrying task for {result[0].data}. Retry count: {result[0].extra_info.retry_times}" + f"Retrying task for {result[0].uid.action_id}. Retry count: {result[0].extra_info.retry_times}" ) retry_task = create_task(self.worker_task(group_samples_for_retry=result)) pending_tasks.add(retry_task) else: self.failed_samples_count += 1 self.logger.error( - f"Max retry reached for {result[0].data}. Not retrying. Current failed count: {self.failed_samples_count}" + f"Max retry reached for {result[0].uid.action_id}. Not retrying. 
Current failed count: {self.failed_samples_count}" ) self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote()) waiting_tasks = pending_tasks diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py index c23607549..e00787b14 100644 --- a/xtuner/v1/ray/environment/base_env.py +++ b/xtuner/v1/ray/environment/base_env.py @@ -23,12 +23,25 @@ class BaseEnvironment(ABC): """ def __init__( - self, environment: str, rollout_pg: Any, rollout_cfg: Any, judger_pg: Any = None, judger_cfg: Any = None + self, + environment: str, + rollout_pg: Any, + rollout_cfg: Any, + judger_pg: Any = None, + judger_cfg: Any = None, + rollout_controller=None, + judger_controller=None, ): judger_pg = judger_pg if judger_pg else rollout_pg self.environment = environment - self.rollout_controller = self.init_rollout_controller(rollout_cfg, rollout_pg) - self.judger_controller = self.init_judger_controller(judger_cfg, judger_pg) + if rollout_controller: + self.rollout_controller = rollout_controller + else: + self.rollout_controller = self.init_rollout_controller(rollout_cfg, rollout_pg) + if judger_controller: + self.judger_controller = judger_controller + else: + self.judger_controller = self.init_judger_controller(judger_cfg, judger_pg) def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any): """Initializes the rollout controller with the appropriate worker diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py index 79d98ee37..ad27e3b45 100644 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ b/xtuner/v1/ray/environment/single_turn_env.py @@ -1,5 +1,6 @@ import asyncio import os +from pathlib import Path from typing import List import ray @@ -31,11 +32,27 @@ class SingleTurnEnvironment(BaseEnvironment): judger_cfg (optional): Configuration for the judger controller. Defaults to None. """ - def __init__(self, environment: str, rollout_pg, rollout_cfg=None, judger_pg=None, judger_cfg=None): - super().__init__(environment, rollout_pg, rollout_cfg, judger_pg, judger_cfg) - worker_log_dir = rollout_cfg.worker_log_dir if rollout_cfg else judger_cfg.worker_log_dir + def __init__( + self, + environment: str, + rollout_pg, + rollout_cfg=None, + judger_pg=None, + judger_cfg=None, + rollout_controller=None, + judger_controller=None, + ): + super().__init__( + environment, rollout_pg, rollout_cfg, judger_pg, judger_cfg, rollout_controller, judger_controller + ) + if rollout_cfg: + worker_log_dir = rollout_cfg.worker_log_dir + elif judger_cfg: + worker_log_dir = judger_cfg.worker_log_dir + else: + worker_log_dir = Path.cwd() / "work_dir" self.logger = get_logger(log_dir=worker_log_dir, tag="SingleTurnEnv") - if rollout_cfg.enable_return_routed_experts: + if rollout_cfg and rollout_cfg.enable_return_routed_experts: self.logger.info("!!! Enable `return routed experts` in rollout controller. !!!") self.rollout_timeout = rollout_cfg.rollout_timeout if rollout_cfg else 1200.0 self.judger_timeout = judger_cfg.judger_timeout if judger_cfg else 1200.0 diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py index 5a89bdefe..ffe3a432f 100644 --- a/xtuner/v1/ray/evaluator.py +++ b/xtuner/v1/ray/evaluator.py @@ -250,7 +250,7 @@ async def concurrent_eval_task_runner(self): retry_task = create_task(self.eval_worker_task(result)) pending_tasks.add(retry_task) else: - self.logger.error(f"Max retry reached for {result.data}. 
Not retrying.") + self.logger.error(f"Max retry reached for {result.uid.action_id}. Not retrying.") self.failed_samples_count += 1 waiting_tasks = pending_tasks diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 4758fba19..64870416e 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -79,7 +79,6 @@ async def get_worker(self, session_id: int) -> Any: return worker[0] -@ray.remote(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) class RolloutController: """Controller for managing and coordinating multiple RolloutWorker actors.""" @@ -325,18 +324,8 @@ async def rollout( extra_info=extra_info, ) try: - response, http_result = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout) - if http_result.is_success: - return response - elif http_result.is_retryable or http_result.is_server_error: - response.finish_reason = "failed" - return response - elif http_result.is_client_error: - response.finish_reason = "skipped" - return response - else: # unknown error - raise RuntimeError("Unknown error occurred during rollout. Error message: ", http_result.error_message) - + response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout) + return response except asyncio.TimeoutError: self.logger.error("Get response from rollout worker timeout and return the failed response.") failed_response = RLRolloutResponseItem( diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/ray/rollout/lmdeploy.py index ea29b790c..23332f05d 100644 --- a/xtuner/v1/ray/rollout/lmdeploy.py +++ b/xtuner/v1/ray/rollout/lmdeploy.py @@ -36,7 +36,6 @@ def run_lmdeploy_server_wrapper(lmdeploy_config_namespace: Namespace): serve(**lmdeploy_serve_kwargs) -@ray.remote class LMDeployWorker(RolloutWorker): """A Ray actor that runs a text generation server using LMDeploy.""" diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 6921b2271..614f087ca 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -356,19 +356,26 @@ async def rollout_task( return rollout_response, http_result else: if http_result.is_retryable: + failed_rollout_response.finish_reason = "failed" self.logger.warning(f"Retryable error occurred during rollout request {uid} to {http_result.url}") - return failed_rollout_response, http_result + return failed_rollout_response elif http_result.is_server_error: + failed_rollout_response.finish_reason = "failed" self.logger.error( f"Server error during rollout request {uid} to {http_result.url}, please check the server logs." ) http_result.url = self.server_url - return failed_rollout_response, http_result - else: # http_result.is_client_error: + return failed_rollout_response + elif http_result.is_client_error: + failed_rollout_response.finish_reason = "skipped" self.logger.error( f"Client error during rollout request {uid} to {http_result.url} and skip this request." 
diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py
index a3a256ab4..247380eca 100644
--- a/xtuner/v1/utils/rl_test_utils.py
+++ b/xtuner/v1/utils/rl_test_utils.py
@@ -3,12 +3,17 @@
 import time
 from typing import Any, Dict, List

+import httpx
+import ray
 import requests
 import uvicorn
 from fastapi import FastAPI
 from pydantic import BaseModel, ConfigDict, Field

 from xtuner.v1.ray.judger.native import NativeJudger
+from xtuner.v1.ray.rollout import LMDeployWorker
+
+from .httpx_utils import HttpRequestErrorType, HttpRequestResult


 app = FastAPI()
@@ -32,6 +37,70 @@
     return eos_token_id


+@ray.remote
+class MockTimeoutRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            raise httpx.TimeoutException("Mocked timeout error")
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
+@ray.remote
+class MockRequestErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            raise httpx.RequestError("Mocked httpx request error")
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
+@ray.remote
+class MockClientErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            res = httpx.Response(400, request=req)
+            raise httpx.HTTPStatusError("Mocked client error", request=req, response=res)
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
+@ray.remote
+class MockServerErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            res = httpx.Response(500, request=req)
+            raise httpx.HTTPStatusError("Mocked server error", request=req, response=res)
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
 class JudgeRequest(BaseModel):
     response: str
     label: str
From 96348101005b237342e5837fcc828946108a6977 Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Thu, 13 Nov 2025 20:55:58 +0800
Subject: [PATCH 03/18] fix ut
---
 tests/ray/test_mock_rollout.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index 460d5d922..e99f9c1e7 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -111,10 +111,11 @@
         self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
         self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
         self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be left paused.")
-
+        ray.get(self.test_env.shutdown.remote(), timeout=300)

     def test_rollout_with_timeout_mock(self):
         self._run_mock_test(MockTimeoutRolloutController, "timeout")

+
     def test_rollout_with_request_error_mock(self):
         self._run_mock_test(MockRequestErrorRolloutController, "request error")

From 8ba436a083ab5c1f855518875382b0ac5da17a6c Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Thu, 13 Nov 2025 20:57:39 +0800
Subject: [PATCH 04/18] fix
---
 tests/ray/test_mock_rollout.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index e99f9c1e7..85e6ec658 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -98,6 +98,9 @@
             tokenizer=tokenizer,
         )

+    def tearDown(self):
+        ray.shutdown()
+
     def _run_mock_test(self, mock_controller_cls, error_name: str):
         rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
         self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller)
         self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env)
@@ -111,8 +114,7 @@
         self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
         self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
         self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be left paused.")
-        ray.get(self.test_env.shutdown.remote(), timeout=300)
-
+
     def test_rollout_with_timeout_mock(self):
         self._run_mock_test(MockTimeoutRolloutController, "timeout")

From d3aec12daf278ae73e768be4844f29ff45a7ab2a Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Thu, 13 Nov 2025 21:17:36 +0800
Subject: [PATCH 05/18] fix ut
---
 tests/ray/test_mock_rollout.py      | 5 ++---
 xtuner/v1/data_proto/rl_data.py     | 3 ++-
 xtuner/v1/ray/dataflow/flow.py      | 9 ++++++++-
 xtuner/v1/ray/rollout/controller.py | 5 +++++
 xtuner/v1/ray/rollout/worker.py     | 1 +
 xtuner/v1/utils/rl_test_utils.py    | 8 ++++----
 6 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index 85e6ec658..c5b1091ae 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -75,6 +75,7 @@
             tokenizer_path=MODEL_PATH,
             tensor_parallel_size=1,
             context_length=self.max_prompt_length + self.max_response_length,
+            max_retry_per_worker=2
         )
         tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
@@ -106,7 +107,6 @@
     def _run_mock_test(self, mock_controller_cls, error_name: str):
         rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
         self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller)
         self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env)
-        # Run a single batch
         completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))

         status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
@@ -114,11 +114,10 @@
         self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
         self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
         self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be left paused.")
-
+
     def test_rollout_with_timeout_mock(self):
         self._run_mock_test(MockTimeoutRolloutController, "timeout")

-
     def test_rollout_with_request_error_mock(self):
         self._run_mock_test(MockRequestErrorRolloutController, "request error")

diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
index c723d0f82..c9efe0f4f 100644
--- a/xtuner/v1/data_proto/rl_data.py
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -153,7 +153,8 @@
     # If any item is in the abort state, the check is effectively skipped and the group will be rolled out again next time
     is_abort = any(item.env.rollout.finish_reason == "abort" for item in group_data_items)
-    if is_abort:
+    is_skipped = all(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
+    if is_abort or is_skipped:
         return True

     no_failures = all(item.env.rollout.finish_reason != "failed" for item in group_data_items)
diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py
index 174f7bedc..e2901d3f6 100644
--- a/xtuner/v1/ray/dataflow/flow.py
+++ b/xtuner/v1/ray/dataflow/flow.py
@@ -115,6 +115,7 @@
         self.finished_samples_count = 0
         self.unfinished_samples_count = 0
         self.failed_samples_count = 0
+        self.input_error_sample_count = 0
         self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
         self.target_batch_size = self.config.global_batch_size
         self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
@@ -197,6 +198,7 @@
                 f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request."
             )
             self.logger.debug(f"Worker task skipped successfully for {group_data_items[0].uid.action_id}.")
+            self.input_error_sample_count += 1
             return

@@ -235,6 +237,7 @@
         while (
             self.finished_samples_count < self.target_batch_size
             and self.failed_samples_count < self.target_batch_size
+            and self.input_error_sample_count < self.target_batch_size
         ):
@@ -276,7 +279,10 @@
             if self.finished_samples_count >= self.target_batch_size:
                 self.logger.info("Target batch size reached. Pausing env controller.")
-            if self.failed_samples_count >= self.target_batch_size:
+            if (
+                self.failed_samples_count >= self.target_batch_size
+                or self.input_error_sample_count >= self.target_batch_size
+            ):
                 self.logger.info("Max failed samples reached. 
Pausing env controller.") # NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()` @@ -350,6 +356,7 @@ async def run( self.finished_samples_count = 0 self.unfinished_samples_count = 0 self.failed_samples_count = 0 + self.input_error_sample_count = 0 self.target_batch_size = num if num and num > 0 else self.config.global_batch_size self.logger.info(f"Start generate dataflow and target batch size set to {self.target_batch_size}.") self.sample_params = sample_params if sample_params else self.config.sample_params diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 64870416e..8ba4a96e8 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -269,6 +269,7 @@ def deactivate_worker_by_url(self, url): f"Rollout worker {url} failed {self.url_failed_counts[url]} times, but not deactivated yet." ) return + self.logger.error(f"Deactivating rollout worker {url} due to repeated failures.") inactive_workers = self.active_url_to_workers.get(url) self.active_workers_to_status[inactive_workers] = False @@ -325,6 +326,10 @@ async def rollout( ) try: response = await asyncio.wait_for(response_ref, timeout=self.config.rollout_timeout) + if response.extra_info and "url" in response.extra_info: + url = response.extra_info["url"] + if response.finish_reason == "failed": + self.deactivate_worker_by_url(url) return response except asyncio.TimeoutError: self.logger.error("Get response from rollout worker timeout and return the failed response.") diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 614f087ca..7ca47de0a 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -361,6 +361,7 @@ async def rollout_task( return failed_rollout_response elif http_result.is_server_error: failed_rollout_response.finish_reason = "failed" + failed_rollout_response.extra_info = {"url": self.server_url} self.logger.error( f"Server error during rollout request {uid} to {http_result.url}, please check the server logs." 
) diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py index 247380eca..158795b15 100644 --- a/xtuner/v1/utils/rl_test_utils.py +++ b/xtuner/v1/utils/rl_test_utils.py @@ -45,7 +45,7 @@ async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: except Exception as e: error_type = HttpRequestErrorType.from_exception(e) result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked exception: {e.__class__.__name__}") + self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}") return result def launch_server(self): @@ -60,7 +60,7 @@ async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: except Exception as e: error_type = HttpRequestErrorType.from_exception(e) result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked exception: {e.__class__.__name__}") + self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}") return result def launch_server(self): @@ -77,7 +77,7 @@ async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: except Exception as e: error_type = HttpRequestErrorType.from_exception(e) result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked exception: {e.__class__.__name__}") + self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") return result def launch_server(self): @@ -94,7 +94,7 @@ async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: except Exception as e: error_type = HttpRequestErrorType.from_exception(e) result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked exception: {e.__class__.__name__}") + self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") return result def launch_server(self): From 4ca67982582372c86f3a3ab7d321a337d398c797 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 13 Nov 2025 21:46:43 +0800 Subject: [PATCH 06/18] fix --- tests/ray/test_mock_rollout.py | 2 +- tests/ray/test_rollout.py | 6 ++--- xtuner/v1/ray/environment/base_env.py | 4 +++- xtuner/v1/ray/rollout/controller.py | 6 ++--- xtuner/v1/ray/rollout/worker.py | 33 ++++++++------------------- xtuner/v1/utils/httpx_utils.py | 17 ++++++++++++++ xtuner/v1/utils/rl_test_utils.py | 5 ---- 7 files changed, 36 insertions(+), 37 deletions(-) diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py index c5b1091ae..7caf1ce86 100644 --- a/tests/ray/test_mock_rollout.py +++ b/tests/ray/test_mock_rollout.py @@ -103,7 +103,7 @@ def tearDown(self): ray.shutdown() def _run_mock_test(self, mock_controller_cls, error_name: str): - rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg) + rollout_controller = ray.remote(mock_controller_cls).remote(self.rollout_cfg, self.pg) self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller) self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env) diff --git a/tests/ray/test_rollout.py b/tests/ray/test_rollout.py index 98f2ff563..e10c6225a 100644 --- a/tests/ray/test_rollout.py +++ b/tests/ray/test_rollout.py @@ -118,7 +118,7 @@ def tearDown(self): @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") def 
test_lmdeploy_generate(self): sample_params = SampleParams(temperature=0.0) - rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] + rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) self.assertEqual(res1.finish_reason, "stop") @@ -186,7 +186,7 @@ def test_lmdeploy_turbomind_generate(self): from xtuner.v1.ray.rollout import LMDeployWorker self.rollout_cfg.extra_rollout_config["lmdeploy_backend"] = "turbomind" sample_params = SampleParams(temperature=0.0) - rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] + rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) res2 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) self.assertEqual(res1, res2, f"res1 != res2, res1={res1}, res2={res2}") @@ -197,7 +197,7 @@ def test_sglang_generate(self): from xtuner.v1.ray.rollout import SGLangWorker self.rollout_cfg.launch_server_method="multiprocessing" sample_params = SampleParams(temperature=0.0) - rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] + rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined] res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params)) self.assertEqual(res1.finish_reason, "stop") print("Response from SGLang infer:", res1) diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py index e00787b14..c8434fe67 100644 --- a/xtuner/v1/ray/environment/base_env.py +++ b/xtuner/v1/ray/environment/base_env.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, List +import ray + from xtuner.v1.data_proto.rl_data import RLDataFlowItem @@ -68,7 +70,7 @@ def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any): from xtuner.v1.ray.rollout.controller import RolloutController - rollout_controller = RolloutController.remote(rollout_cfg, placement_group) # type: ignore[attr-defined] + rollout_controller = ray.remote(RolloutController).remote(rollout_cfg, placement_group) # type: ignore[attr-defined] return rollout_controller def init_judger_controller(self, judger_cfg: Any, placement_group: Any): diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 8ba4a96e8..74e4acc52 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -136,15 +136,15 @@ def _get_worker_cls(self): if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": from .lmdeploy import LMDeployWorker - return LMDeployWorker + return ray.remote(LMDeployWorker) elif os.environ.get("XTUNER_USE_VLLM") == "1": from .vllm import vLLMWorker - return vLLMWorker + return ray.remote(vLLMWorker) elif os.environ.get("XTUNER_USE_SGLANG") == "1": from .sglang import SGLangWorker - return SGLangWorker + return ray.remote(SGLangWorker) else: raise NotImplementedError( "Rollout backend is not supported." 
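            )

Wrapping the worker and controller classes with `ray.remote(...)` at the call site, rather than decorating them at class definition time, is what allows plain Python subclasses (such as the mock controllers and workers in the tests above) to be swapped in before any actor is created. A minimal sketch of the pattern, with illustrative class names:

    import ray

    class Worker:                      # plain class, no @ray.remote decorator
        def ping(self):
            return "real"

    class MockWorker(Worker):          # ordinary subclassing still works
        def ping(self):
            return "mock"

    actor = ray.remote(MockWorker).remote()   # decorate only when spawning
    assert ray.get(actor.ping.remote()) == "mock"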
diff --git a/xtuner/v1/ray/rollout/worker.py b/xtuner/v1/ray/rollout/worker.py index 7ca47de0a..0839c07cc 100644 --- a/xtuner/v1/ray/rollout/worker.py +++ b/xtuner/v1/ray/rollout/worker.py @@ -6,7 +6,7 @@ import traceback import uuid from abc import abstractmethod -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import httpx import ray @@ -21,7 +21,7 @@ from xtuner.v1.ray.base import AutoAcceleratorWorkers, SingleAcceleratorWorker from xtuner.v1.ray.config import RolloutConfig from xtuner.v1.utils import get_logger -from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult +from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult, set_rollout_response_status class RolloutWorker(SingleAcceleratorWorker): @@ -315,7 +315,7 @@ async def rollout_task( extra_params: dict, format: str, extra_info: dict, - ) -> Tuple[RLRolloutResponseItem, HttpRequestResult]: + ) -> RLRolloutResponseItem: uid = extra_info.get("action_id", str(uuid.uuid4())) response = None failed_rollout_response = RLRolloutResponseItem(finish_reason="failed") @@ -353,30 +353,15 @@ async def rollout_task( finally: if hasattr(response, "aclose"): await response.aclose() - return rollout_response, http_result + return rollout_response else: - if http_result.is_retryable: - failed_rollout_response.finish_reason = "failed" - self.logger.warning(f"Retryable error occurred during rollout request {uid} to {http_result.url}") - return failed_rollout_response - elif http_result.is_server_error: - failed_rollout_response.finish_reason = "failed" - failed_rollout_response.extra_info = {"url": self.server_url} - self.logger.error( - f"Server error during rollout request {uid} to {http_result.url}, please check the server logs." - ) - http_result.url = self.server_url - return failed_rollout_response - elif http_result.is_client_error: - failed_rollout_response.finish_reason = "skipped" - self.logger.error( - f"Client error during rollout request {uid} to {http_result.url} and skip this request." - ) - return failed_rollout_response - else: + if http_result.is_unknown_error: raise RuntimeError( f"Unexpected error during rollout request {uid} to {http_result.url}: {http_result.exception}" ) + else: + set_rollout_response_status(http_result, failed_rollout_response, self.server_url) + return failed_rollout_response async def _handle_stream_response(self, uid, sample_params, extra_params, response) -> RLRolloutResponseItem: last_trajectory = "" @@ -546,7 +531,7 @@ async def rollout( extra_params: dict = dict(), format: str = "openai", extra_info: dict = dict(), - ) -> Tuple[RLRolloutResponseItem, HttpRequestResult]: + ) -> RLRolloutResponseItem: """Public method to initiate a rollout. Args: diff --git a/xtuner/v1/utils/httpx_utils.py b/xtuner/v1/utils/httpx_utils.py index 68c8fb5f7..2259f312c 100644 --- a/xtuner/v1/utils/httpx_utils.py +++ b/xtuner/v1/utils/httpx_utils.py @@ -5,6 +5,8 @@ import httpx +from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem + class HttpRequestErrorType(IntEnum): """An enumeration for HTTP status codes and client-side request errors. 
@@ -116,6 +118,10 @@ def is_retryable(self) -> bool: HttpRequestErrorType.REQUEST_ERROR, } + @property + def is_unknown_error(self) -> bool: + return self.error_type == HttpRequestErrorType.UNKNOWN_ERROR + @property def is_client_error(self) -> bool: return 400 <= self.error_type < 500 @@ -123,3 +129,14 @@ def is_client_error(self) -> bool: @property def is_server_error(self) -> bool: return 500 <= self.error_type < 600 + + +def set_rollout_response_status(http_result: HttpRequestResult, response: RLRolloutResponseItem, server_url=None): + if http_result.is_retryable: + response.finish_reason = "failed" + elif http_result.is_client_error: + response.finish_reason = "skipped" + elif http_result.is_server_error: + response.finish_reason = "failed" + if server_url: + response.extra_info = {"url": server_url} diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py index 158795b15..dd888cb54 100644 --- a/xtuner/v1/utils/rl_test_utils.py +++ b/xtuner/v1/utils/rl_test_utils.py @@ -4,7 +4,6 @@ from typing import Any, Dict, List import httpx -import ray import requests import uvicorn from fastapi import FastAPI @@ -37,7 +36,6 @@ def get_eos_token(model_path: str) -> int | List[int]: return eos_token_id -@ray.remote class MockTimeoutRolloutWorker(LMDeployWorker): async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: try: @@ -52,7 +50,6 @@ def launch_server(self): pass # Override -@ray.remote class MockRequestErrorRolloutWorker(LMDeployWorker): async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: try: @@ -67,7 +64,6 @@ def launch_server(self): pass # Override -@ray.remote class MockClientErrorRolloutWorker(LMDeployWorker): async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: try: @@ -84,7 +80,6 @@ def launch_server(self): pass # Override -@ray.remote class MockServerErrorRolloutWorker(LMDeployWorker): async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: try: From d3aec12daf278ae73e768be4844f29ff45a7ab2a Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Thu, 13 Nov 2025 21:50:36 +0800 Subject: [PATCH 07/18] fix --- xtuner/v1/ray/dataflow/flow.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py index e2901d3f6..38fa29c49 100644 --- a/xtuner/v1/ray/dataflow/flow.py +++ b/xtuner/v1/ray/dataflow/flow.py @@ -115,7 +115,7 @@ def __init__( self.finished_samples_count = 0 self.unfinished_samples_count = 0 self.failed_samples_count = 0 - self.input_error_sample_count = 0 + self.skipped_sample_count = 0 self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow") self.target_batch_size = self.config.global_batch_size self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}") @@ -198,7 +198,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request." 
) self.logger.debug(f"Worker task skipped successfully for {group_data_items[0].uid.action_id}.") - self.input_error_sample_count += 1 + self.skipped_sample_count += 1 return # step 3: filter @@ -237,7 +237,7 @@ async def concurrent_task_runner(self): while ( self.finished_samples_count < self.target_batch_size and self.failed_samples_count < self.target_batch_size - and self.input_error_sample_count < self.target_batch_size + and self.skipped_sample_count < self.target_batch_size ): if self.finished_samples_count >= next_update_threshold: pbar.n = self.finished_samples_count @@ -279,10 +279,7 @@ async def concurrent_task_runner(self): if self.finished_samples_count >= self.target_batch_size: self.logger.info("Target batch size reached. Pausing env controller.") - if ( - self.failed_samples_count >= self.target_batch_size - or self.input_error_sample_count >= self.target_batch_size - ): + if self.failed_samples_count >= self.target_batch_size or self.skipped_sample_count >= self.target_batch_size: self.logger.info("Max failed samples reached. Pausing env controller.") # NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()` @@ -356,7 +353,7 @@ async def run( self.finished_samples_count = 0 self.unfinished_samples_count = 0 self.failed_samples_count = 0 - self.input_error_sample_count = 0 + self.skipped_sample_count = 0 self.target_batch_size = num if num and num > 0 else self.config.global_batch_size self.logger.info(f"Start generate dataflow and target batch size set to {self.target_batch_size}.") self.sample_params = sample_params if sample_params else self.config.sample_params From 0e987006477b258b1c0d0822eec7b72719e18b4f Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 17 Nov 2025 17:52:45 +0800 Subject: [PATCH 08/18] fix ut --- tests/ray/test_evaluator.py | 14 -------------- tests/ray/test_mock_rollout.py | 13 +++++-------- 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/tests/ray/test_evaluator.py b/tests/ray/test_evaluator.py index e3718f831..a663cc381 100644 --- a/tests/ray/test_evaluator.py +++ b/tests/ray/test_evaluator.py @@ -107,20 +107,6 @@ def custom_compute_metric(samples): custom_evaluator = Evaluator.remote(custom_evaluator_cfg, self.test_env) custom_correctness = ray.get(custom_evaluator.run.remote()) self.assertEqual(correctness['accuracy'], custom_correctness['custom_accuracy']) - - @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") - def test_lmdeploy_evaluator_with_failed_response(self): - evaluator_cfg = EvaluatorConfig( - dataset_cfg=self.eval_dataset_cfg, - tokenizer=self.tokenizer, - max_concurrent=1, - eval_sample_ratio=1, # generate 5 samples - sample_params=SampleParams(temperature=2.5), # invalid temperature to trigger error - max_retry_times=1, - ) - evaluator = Evaluator.remote(evaluator_cfg, self.test_env) - correctness = ray.get(evaluator.run.remote()) - self.assertEqual(len(correctness), 0) if __name__ == '__main__': unittest.main() diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py index 7caf1ce86..b627d0cac 100644 --- a/tests/ray/test_mock_rollout.py +++ b/tests/ray/test_mock_rollout.py @@ -25,22 +25,19 @@ @ray.remote class MockTimeoutRolloutController(RolloutController): def _get_worker_cls(self): - return MockTimeoutRolloutWorker - + return ray.remote(MockTimeoutRolloutWorker) @ray.remote class MockRequestErrorRolloutController(RolloutController): def _get_worker_cls(self): - return MockRequestErrorRolloutWorker - 
+        return ray.remote(MockRequestErrorRolloutWorker)
 @ray.remote
 class MockClientErrorRolloutController(RolloutController):
     def _get_worker_cls(self):
-        return MockClientErrorRolloutWorker
-
+        return ray.remote(MockClientErrorRolloutWorker)
 @ray.remote
 class MockServerErrorRolloutController(RolloutController):
     def _get_worker_cls(self):
-        return MockServerErrorRolloutWorker
+        return ray.remote(MockServerErrorRolloutWorker)
 
 class TestMockRollout(unittest.TestCase):
     @classmethod
@@ -103,7 +100,7 @@ def tearDown(self):
         ray.shutdown()
 
     def _run_mock_test(self, mock_controller_cls, error_name: str):
-        rollout_controller = ray.remote(mock_controller_cls).remote(self.rollout_cfg, self.pg)
+        rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
         self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller)
         self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env)

From 7fa5c7782b136b1b181b72dd6f821102cf43811b Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Mon, 17 Nov 2025 20:03:36 +0800
Subject: [PATCH 09/18] fix comments

---
 tests/ray/test_mock_rollout.py   |  5 +----
 xtuner/v1/data_proto/rl_data.py  |  2 +-
 xtuner/v1/utils/httpx_utils.py   | 10 ++++++----
 xtuner/v1/utils/rl_test_utils.py |  3 ++-
 4 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index b627d0cac..a47a71459 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -4,13 +4,10 @@
 from transformers import AutoTokenizer
 import torch
 from xtuner.v1.ray.config.worker import RolloutConfig
-from xtuner.v1.ray.judger.controller import JudgerConfig
 from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
 from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
-from xtuner.v1.data_proto.rl_data import SampleParams
 from xtuner.v1.ray.environment import SingleTurnEnvironment
-from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
-from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets
+from xtuner.v1.datasets import RLTokenizeFnConfig
 from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
 from xtuner.v1.ray.rollout.controller import RolloutController
 # Import mock workers
diff --git a/xtuner/v1/data_proto/rl_data.py b/xtuner/v1/data_proto/rl_data.py
index c9efe0f4f..bcf7a9f4c 100644
--- a/xtuner/v1/data_proto/rl_data.py
+++ b/xtuner/v1/data_proto/rl_data.py
@@ -153,7 +153,7 @@ def check_dataflow_item(group_data_items):
     # If any item is in the "abort" state, this amounts to skipping the check; the item will be rolled out again next time
     is_abort = any(item.env.rollout.finish_reason == "abort" for item in group_data_items)
-    is_skipped = all(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
+    is_skipped = any(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
     if is_abort or is_skipped:
         return True
 
diff --git a/xtuner/v1/utils/httpx_utils.py b/xtuner/v1/utils/httpx_utils.py
index 2259f312c..4038d8369 100644
--- a/xtuner/v1/utils/httpx_utils.py
+++ b/xtuner/v1/utils/httpx_utils.py
@@ -16,9 +16,9 @@ class HttpRequestErrorType(IntEnum):
     an HTTP status code.
Example: - if error_code == RequestErrorType.BAD_REQUEST: + if error_code == HttpRequestErrorType.BAD_REQUEST: print("Bad request from server!") - elif error_code == RequestErrorType.TIMEOUT_ERROR: + elif error_code == HttpRequestErrorType.TIMEOUT_ERROR: print("Client-side request timed out!") """ @@ -86,7 +86,7 @@ def __post_init__(self): default_messages = { HttpRequestErrorType.UNKNOWN_ERROR: f"An unknown error {self.exception} occurred, Traceback: {traceback.format_exc()}", HttpRequestErrorType.TIMEOUT_ERROR: "The request timed out.", - HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred occurred. TypeError: {type(self.exception)}", + HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred. TypeError: {type(self.exception)}", HttpRequestErrorType.BAD_REQUEST: f"Bad Request (400): The server could not process the request {self.payload}", HttpRequestErrorType.UNAUTHORIZED: "Unauthorized (401): Authentication failed or is required.", HttpRequestErrorType.FORBIDDEN: "Forbidden (403): Access is denied.", @@ -139,4 +139,6 @@ def set_rollout_response_status(http_result: HttpRequestResult, response: RLRoll elif http_result.is_server_error: response.finish_reason = "failed" if server_url: - response.extra_info = {"url": server_url} + if response.extra_info is None: + response.extra_info = {} + response.extra_info.update({"url": server_url}) diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py index dd888cb54..4e7b7f877 100644 --- a/xtuner/v1/utils/rl_test_utils.py +++ b/xtuner/v1/utils/rl_test_utils.py @@ -53,7 +53,8 @@ def launch_server(self): class MockRequestErrorRolloutWorker(LMDeployWorker): async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: try: - raise httpx.RequestError("Mocked httpx request error") + req = httpx.Request("POST", url) + raise httpx.RequestError("Mocked httpx request error", request=req) except Exception as e: error_type = HttpRequestErrorType.from_exception(e) result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) From cbd5a4cf50051b9b2c7ad04c6a4ac973e22b2ab9 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 17 Nov 2025 20:47:32 +0800 Subject: [PATCH 10/18] fix comments --- tests/ray/test_mock_rollout.py | 1 - xtuner/v1/ray/environment/single_turn_env.py | 2 ++ xtuner/v1/ray/rollout/controller.py | 5 +++-- xtuner/v1/utils/httpx_utils.py | 7 +++++-- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py index a47a71459..daad23ccb 100644 --- a/tests/ray/test_mock_rollout.py +++ b/tests/ray/test_mock_rollout.py @@ -16,7 +16,6 @@ MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] -TEST_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] resource_map = {"npu": "NPU", "cuda": "GPU"} @ray.remote diff --git a/xtuner/v1/ray/environment/single_turn_env.py b/xtuner/v1/ray/environment/single_turn_env.py index ad27e3b45..6bb71a6b2 100644 --- a/xtuner/v1/ray/environment/single_turn_env.py +++ b/xtuner/v1/ray/environment/single_turn_env.py @@ -30,6 +30,8 @@ class SingleTurnEnvironment(BaseEnvironment): judger_pg (Any): The placement group for scheduling judger Ray actors. Defaults to None indicates using the rollout_pg. judger_cfg (optional): Configuration for the judger controller. Defaults to None. + rollout_controller (optional): An instance of the rollout controller. Defaults to None. 
+ judger_controller (optional): An instance of the judger controller. Defaults to None. """ def __init__( diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 74e4acc52..16d577596 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -269,9 +269,10 @@ def deactivate_worker_by_url(self, url): f"Rollout worker {url} failed {self.url_failed_counts[url]} times, but not deactivated yet." ) return - self.logger.error(f"Deactivating rollout worker {url} due to repeated failures.") inactive_workers = self.active_url_to_workers.get(url) - self.active_workers_to_status[inactive_workers] = False + if inactive_workers: + self.logger.warning(f"Deactivating rollout worker {url} due to repeated failures.") + self.active_workers_to_status[inactive_workers] = False async def rollout( self, diff --git a/xtuner/v1/utils/httpx_utils.py b/xtuner/v1/utils/httpx_utils.py index 4038d8369..c815661a2 100644 --- a/xtuner/v1/utils/httpx_utils.py +++ b/xtuner/v1/utils/httpx_utils.py @@ -1,3 +1,4 @@ +import copy import traceback from dataclasses import dataclass, field from enum import IntEnum @@ -80,8 +81,10 @@ def __post_init__(self): error_msg is not already set.""" # Only generate a message if one hasn't been provided already. if self.error_msg is None and self.error_type != HttpRequestErrorType.SUCCESS: + log_payload = {} if self.payload is not None and "input_ids" in self.payload: - self.payload["input_ids"] = str(self.payload["input_ids"]) + log_payload = copy.deepcopy(self.payload) + log_payload["input_ids"] = str(log_payload["input_ids"]) default_messages = { HttpRequestErrorType.UNKNOWN_ERROR: f"An unknown error {self.exception} occurred, Traceback: {traceback.format_exc()}", @@ -105,7 +108,7 @@ def __post_init__(self): ) if self.error_type == HttpRequestErrorType.REQUEST_ERROR and self.exception: if hasattr(self.exception, "__cause__") and self.exception.__cause__: - self.error_msg += f"__cause__: {self.exception.__cause__}" + self.error_msg += f" __cause__: {self.exception.__cause__}" @property def is_success(self) -> bool: From bc048155fd6700b851a3eeab778813a12cbf3764 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Mon, 17 Nov 2025 21:05:29 +0800 Subject: [PATCH 11/18] fix comments --- tests/ray/test_mock_rollout.py | 8 ++++++-- xtuner/v1/utils/httpx_utils.py | 10 ++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py index daad23ccb..c7f74d929 100644 --- a/tests/ray/test_mock_rollout.py +++ b/tests/ray/test_mock_rollout.py @@ -106,17 +106,21 @@ def _run_mock_test(self, mock_controller_cls, error_name: str): print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}") self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.") self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.") - self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected {self.global_batch_size} rollouts to be interrupted.") + self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.") + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") def test_rollout_with_timeout_mock(self): self._run_mock_test(MockTimeoutRolloutController, "timeout") - + + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy 
backend is not enabled") def test_rollout_with_request_error_mock(self): self._run_mock_test(MockRequestErrorRolloutController, "request error") + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") def test_rollout_with_client_error_mock(self): self._run_mock_test(MockClientErrorRolloutController, "client error") + @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled") def test_rollout_with_server_error_mock(self): self._run_mock_test(MockServerErrorRolloutController, "server error") diff --git a/xtuner/v1/utils/httpx_utils.py b/xtuner/v1/utils/httpx_utils.py index c815661a2..199beaf86 100644 --- a/xtuner/v1/utils/httpx_utils.py +++ b/xtuner/v1/utils/httpx_utils.py @@ -59,7 +59,7 @@ def from_exception(cls, e: Exception) -> "HttpRequestErrorType": return cls.UNKNOWN_ERROR if isinstance(e, httpx.RequestError): - # This check comes after its subclasses (TimeoutException, HTTPStatusError) + # This check comes after its subclass (TimeoutException) return cls.REQUEST_ERROR # For any other standard Python exception @@ -89,12 +89,12 @@ def __post_init__(self): default_messages = { HttpRequestErrorType.UNKNOWN_ERROR: f"An unknown error {self.exception} occurred, Traceback: {traceback.format_exc()}", HttpRequestErrorType.TIMEOUT_ERROR: "The request timed out.", - HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred. TypeError: {type(self.exception)}", - HttpRequestErrorType.BAD_REQUEST: f"Bad Request (400): The server could not process the request {self.payload}", + HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred. ExceptionType: {type(self.exception)}", + HttpRequestErrorType.BAD_REQUEST: f"Bad Request (400): The server could not process the request {log_payload}", HttpRequestErrorType.UNAUTHORIZED: "Unauthorized (401): Authentication failed or is required.", HttpRequestErrorType.FORBIDDEN: "Forbidden (403): Access is denied.", HttpRequestErrorType.NOT_FOUND: "Not Found (404): The resource was not found.", - HttpRequestErrorType.REQUEST_TIMEOUT: f"Request Timeout (408): The server timed out waiting for the request {self.payload}.", + HttpRequestErrorType.REQUEST_TIMEOUT: f"Request Timeout (408): The server timed out waiting for the request {log_payload}.", HttpRequestErrorType.TOO_MANY_REQUESTS: "Too Many Requests (429): Rate limit exceeded.", HttpRequestErrorType.INTERNAL_SERVER_ERROR: f"Internal Server Error (500) {self.exception} occurred in {self.url}, Traceback: {traceback.format_exc()}", HttpRequestErrorType.BAD_GATEWAY: f"Bad Gateway (502) in {self.url}.", @@ -142,6 +142,4 @@ def set_rollout_response_status(http_result: HttpRequestResult, response: RLRoll elif http_result.is_server_error: response.finish_reason = "failed" if server_url: - if response.extra_info is None: - response.extra_info = {} response.extra_info.update({"url": server_url}) From d422029341f15d7346823328aaf3d4679a94b66c Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 18 Nov 2025 11:15:06 +0800 Subject: [PATCH 12/18] fix comments --- xtuner/v1/ray/base/accelerator.py | 5 ++++- xtuner/v1/ray/dataflow/flow.py | 8 +++++--- xtuner/v1/ray/environment/base_env.py | 7 ++++++- xtuner/v1/ray/rollout/controller.py | 1 + xtuner/v1/ray/rollout/sglang.py | 2 -- xtuner/v1/ray/rollout/vllm.py | 2 -- 6 files changed, 16 insertions(+), 9 deletions(-) diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/ray/base/accelerator.py index 0b07666f3..d1efad5bd 100644 --- 
a/xtuner/v1/ray/base/accelerator.py +++ b/xtuner/v1/ray/base/accelerator.py @@ -404,7 +404,10 @@ def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup): rank_bundle_idx_list = [] for rank, bundle_idx in enumerate(sorted_bundle_idxs): worker = worker_cls.options( - placement_group=pg, placement_group_bundle_index=bundle_idx, **pg_options + max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)), + placement_group=pg, + placement_group_bundle_index=bundle_idx, + **pg_options, ).remote(worker_config, rank, master_addr, master_port, world_size, device_type) workers_list.append(worker) rank_bundle_idx_list.append((rank, bundle_idx)) diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py index 38fa29c49..8b80d10b4 100644 --- a/xtuner/v1/ray/dataflow/flow.py +++ b/xtuner/v1/ray/dataflow/flow.py @@ -236,8 +236,8 @@ async def concurrent_task_runner(self): next_update_threshold = update_step while ( self.finished_samples_count < self.target_batch_size - and self.failed_samples_count < self.target_batch_size - and self.skipped_sample_count < self.target_batch_size + and self.failed_samples_count < self.target_batch_size * self.config.max_retry_times + and self.skipped_sample_count < self.target_batch_size * self.config.max_retry_times ): if self.finished_samples_count >= next_update_threshold: pbar.n = self.finished_samples_count @@ -279,8 +279,10 @@ async def concurrent_task_runner(self): if self.finished_samples_count >= self.target_batch_size: self.logger.info("Target batch size reached. Pausing env controller.") - if self.failed_samples_count >= self.target_batch_size or self.skipped_sample_count >= self.target_batch_size: + if self.failed_samples_count >= self.target_batch_size * self.config.max_retry_times: self.logger.info("Max failed samples reached. Pausing env controller.") + if self.skipped_sample_count >= self.target_batch_size * self.config.max_retry_times: + self.logger.info("Max skipped samples reached. Pausing env controller.") # NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()` # would be queued behind many worker tasks, causing a significant delay. 
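To make the new stop conditions above concrete: failed and skipped samples now each get a budget
of `target_batch_size * max_retry_times` before the dataflow gives up, instead of bailing out
after `target_batch_size` occurrences. A hedged sketch of the loop predicate, as implied by the
diff (`should_keep_sampling` is an illustrative helper, not a function in the codebase; the
parameter names mirror the diff):

    def should_keep_sampling(finished: int, failed: int, skipped: int,
                             target_batch_size: int, max_retry_times: int) -> bool:
        # Stop once the batch is full, or once failures/skips exhaust their retry budget.
        budget = target_batch_size * max_retry_times
        return finished < target_batch_size and failed < budget and skipped < budget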
diff --git a/xtuner/v1/ray/environment/base_env.py b/xtuner/v1/ray/environment/base_env.py index c8434fe67..6d2a12189 100644 --- a/xtuner/v1/ray/environment/base_env.py +++ b/xtuner/v1/ray/environment/base_env.py @@ -1,3 +1,4 @@ +import os from abc import ABC, abstractmethod from typing import Any, List @@ -70,7 +71,11 @@ def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any): from xtuner.v1.ray.rollout.controller import RolloutController - rollout_controller = ray.remote(RolloutController).remote(rollout_cfg, placement_group) # type: ignore[attr-defined] + rollout_controller = ( + ray.remote(RolloutController) + .options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) + .remote(rollout_cfg, placement_group) + ) # type: ignore[attr-defined] return rollout_controller def init_judger_controller(self, judger_cfg: Any, placement_group: Any): diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 16d577596..4b2ccdab7 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -331,6 +331,7 @@ async def rollout( url = response.extra_info["url"] if response.finish_reason == "failed": self.deactivate_worker_by_url(url) + response.extra_info.pop("url", None) return response except asyncio.TimeoutError: self.logger.error("Get response from rollout worker timeout and return the failed response.") diff --git a/xtuner/v1/ray/rollout/sglang.py b/xtuner/v1/ray/rollout/sglang.py index 611a2cb54..3ba153465 100644 --- a/xtuner/v1/ray/rollout/sglang.py +++ b/xtuner/v1/ray/rollout/sglang.py @@ -1,7 +1,6 @@ import os from typing import Any, Dict, List, Union -import ray import requests from urllib3.exceptions import NewConnectionError @@ -11,7 +10,6 @@ from .worker import RolloutWorker -@ray.remote(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000))) class SGLangWorker(RolloutWorker): def __init__( self, diff --git a/xtuner/v1/ray/rollout/vllm.py b/xtuner/v1/ray/rollout/vllm.py index e337a9d0c..400db51ae 100644 --- a/xtuner/v1/ray/rollout/vllm.py +++ b/xtuner/v1/ray/rollout/vllm.py @@ -1,7 +1,6 @@ from argparse import Namespace from typing import Any, Dict, List, Union -import ray import uvloop from vllm.entrypoints.openai.api_server import run_server from vllm.entrypoints.openai.cli_args import make_arg_parser @@ -16,7 +15,6 @@ def run_vllm_server_wrapper(server_args): uvloop.run(run_server(server_args)) -@ray.remote class vLLMWorker(RolloutWorker): def __init__( self, From e44e83ed009c23f8ff20c4d9b7fe999f461e58b8 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 18 Nov 2025 16:04:19 +0800 Subject: [PATCH 13/18] fix ci --- tests/ray/test_update_weight.py | 4 ++-- xtuner/v1/ray/base/accelerator.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/ray/test_update_weight.py b/tests/ray/test_update_weight.py index dbff5ee96..0af1aa576 100644 --- a/tests/ray/test_update_weight.py +++ b/tests/ray/test_update_weight.py @@ -109,7 +109,7 @@ def test_lmdeploy_update_weight_and_generate(self): sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1) # init rollout_update - rollout_controller = RolloutController.remote( + rollout_controller = ray.remote(RolloutController).remote( self.rollout_cfg, self.pg, ) @@ -129,7 +129,7 @@ def test_lmdeploy_update_weight_and_generate(self): # init rollout_ref self.rollout_cfg.skip_load_weights = False - rollout_controller_ref = RolloutController.remote( + rollout_controller_ref = ray.remote(RolloutController).remote( 
         self.rollout_cfg,
         self.pg,
     )

diff --git a/xtuner/v1/ray/base/accelerator.py b/xtuner/v1/ray/base/accelerator.py
index d1efad5bd..973c7d34e 100644
--- a/xtuner/v1/ray/base/accelerator.py
+++ b/xtuner/v1/ray/base/accelerator.py
@@ -404,7 +404,6 @@ def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup):
         rank_bundle_idx_list = []
         for rank, bundle_idx in enumerate(sorted_bundle_idxs):
             worker = worker_cls.options(
-                max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)),
                 placement_group=pg,
                 placement_group_bundle_index=bundle_idx,
                 **pg_options,

From 5c8c03ef87f86c51d7c2f8c3ef15d4dbb989ae00 Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Tue, 18 Nov 2025 20:41:17 +0800
Subject: [PATCH 14/18] fix

---
 tests/ray/test_mock_rollout.py   | 67 +++++++++++++++++++++++++++++--
 xtuner/v1/utils/rl_test_utils.py | 68 +-------------------------------
 2 files changed, 65 insertions(+), 70 deletions(-)

diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index c7f74d929..2d0865f07 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -3,6 +3,7 @@
 import ray
 from transformers import AutoTokenizer
 import torch
+import httpx
 from xtuner.v1.ray.config.worker import RolloutConfig
 from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
 from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
@@ -10,14 +11,74 @@
 from xtuner.v1.datasets import RLTokenizeFnConfig
 from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
 from xtuner.v1.ray.rollout.controller import RolloutController
-# Import mock workers
-from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker
-
+from xtuner.v1.utils.rl_test_utils import HttpRequestResult, HttpRequestErrorType
 MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
 TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
 resource_map = {"npu": "NPU", "cuda": "GPU"}
 
+try:
+    from xtuner.v1.ray.rollout import LMDeployWorker
+except ImportError:
+    LMDeployWorker = object  # or Any
+
+class MockTimeoutRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            raise httpx.TimeoutException("Mocked timeout error")
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+class MockRequestErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            raise httpx.RequestError("Mocked httpx request error", request=req)
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+class MockClientErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            res = httpx.Response(400, request=req)
+            raise httpx.HTTPStatusError("Mocked client error", request=req, response=res)
+        except Exception as e:
+            error_type =
HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") + return result + + def launch_server(self): + pass # Override + +class MockServerErrorRolloutWorker(LMDeployWorker): + async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: + try: + req = httpx.Request("POST", url) + res = httpx.Response(500, request=req) + raise httpx.HTTPStatusError("Mocked server error", request=req, response=res) + except Exception as e: + error_type = HttpRequestErrorType.from_exception(e) + result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) + self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") + return result + + def launch_server(self): + pass # Override + @ray.remote class MockTimeoutRolloutController(RolloutController): def _get_worker_cls(self): diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py index 4e7b7f877..6d89d64a9 100644 --- a/xtuner/v1/utils/rl_test_utils.py +++ b/xtuner/v1/utils/rl_test_utils.py @@ -1,26 +1,21 @@ import json import multiprocessing +import os import time from typing import Any, Dict, List -import httpx import requests import uvicorn from fastapi import FastAPI from pydantic import BaseModel, ConfigDict, Field from xtuner.v1.ray.judger.native import NativeJudger -from xtuner.v1.ray.rollout import LMDeployWorker - -from .httpx_utils import HttpRequestErrorType, HttpRequestResult app = FastAPI() def get_eos_token(model_path: str) -> int | List[int]: - import os - from xtuner.v1.utils.logger import get_logger logger = get_logger() @@ -36,67 +31,6 @@ def get_eos_token(model_path: str) -> int | List[int]: return eos_token_id -class MockTimeoutRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - raise httpx.TimeoutException("Mocked timeout error") - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockRequestErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - raise httpx.RequestError("Mocked httpx request error", request=req) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class MockClientErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - res = httpx.Response(400, request=req) - raise httpx.HTTPStatusError("Mocked client error", request=req, response=res) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - -class 
MockServerErrorRolloutWorker(LMDeployWorker): - async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult: - try: - req = httpx.Request("POST", url) - res = httpx.Response(500, request=req) - raise httpx.HTTPStatusError("Mocked server error", request=req, response=res) - except Exception as e: - error_type = HttpRequestErrorType.from_exception(e) - result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload) - self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}") - return result - - def launch_server(self): - pass # Override - - class JudgeRequest(BaseModel): response: str label: str From 2f87ebf9c1ba5224d2897c4a6ab89152c6eb35b0 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 18 Nov 2025 21:28:15 +0800 Subject: [PATCH 15/18] fix --- tests/ray/test_mock_rollout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py index 2d0865f07..bc28f7630 100644 --- a/tests/ray/test_mock_rollout.py +++ b/tests/ray/test_mock_rollout.py @@ -11,7 +11,7 @@ from xtuner.v1.datasets import RLTokenizeFnConfig from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig from xtuner.v1.ray.rollout.controller import RolloutController -from xtuner.v1.utils.rl_test_utils import HttpRequestResult, HttpRequestErrorType +from xtuner.v1.utils.httpx_utils import HttpRequestResult, HttpRequestErrorType MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] From 03fef2c789ca1e3c547b9bdb103db64822f41669 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 18 Nov 2025 22:14:44 +0800 Subject: [PATCH 16/18] fix deactivate rollout worker --- xtuner/v1/ray/rollout/controller.py | 148 ++++++++++++++++++---------- xtuner/v1/rl/base/worker.py | 14 ++- xtuner/v1/train/rl_trainer.py | 3 + 3 files changed, 110 insertions(+), 55 deletions(-) diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 4b2ccdab7..869a672e0 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -4,6 +4,7 @@ import threading import time from collections import OrderedDict +from dataclasses import dataclass from itertools import cycle from typing import Any, Dict, List, Optional, Union from uuid import uuid4 @@ -22,6 +23,15 @@ from .worker import RolloutWorker +@dataclass +class WorkerInfo: + """A data class to hold all state information for a single worker.""" + + actor: RolloutWorker + is_active: bool = True + failure_count: int = 0 + + class SessionRouter: def __init__( self, @@ -38,6 +48,7 @@ def __init__( self._map: OrderedDict[int, tuple[Any, float]] = OrderedDict() self._worker_cycler = cycle(self._workers) self._lock = asyncio.Lock() + self.logger = get_logger() def _now(self) -> float: return time.time() @@ -60,6 +71,11 @@ def _evict_lru_to_capacity(self): while len(self._map) > self._max_sessions: self._map.popitem(last=False) + def update_active_workers(self, worker_status: Dict[Any, bool]): + self._workers = list(worker_status.items()) + self.logger.debug(f"SessionRouter update active workers: {self._workers}") + self._worker_cycler = cycle(self._workers) + async def get_worker(self, session_id: int) -> Any: async with self._lock: self._evict_expired() @@ -103,21 +119,18 @@ def __init__( ) self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController") self.num_workers = 0 + self.workers_info: Dict[str, WorkerInfo] = {} # url -> WorkerInfo 
self.worker_server_urls: List[str] = [] self.active_rollout_workers: List[RolloutWorker] = [] - self.active_workers_to_status: Dict[RolloutWorker, bool] = {} - self.active_url_to_workers: Dict[str, RolloutWorker] = {} - self.url_failed_counts: Dict[str, int] = {} self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True) self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group( self._get_worker_cls(), infer_config, placement_group ) - self.engine_mesh_list, self.server_url_dict = self.init_workers() + self.engine_mesh_list, self.worker_server_urls_map = self.init_workers() self.start_api_server() # todo(@duanyanhui): add router to replace native round robin - self.router = SessionRouter(self.active_workers_to_status) + self.router = SessionRouter(self._get_worker_status_for_router()) self.sample_params = SampleParams().dict() - # note: 目前默认使用return_token_ids和return_logprob,并且不使用流式 self.extra_params = dict( RolloutExtraParams( stream=False, @@ -132,6 +145,10 @@ def __init__( ) self.print_params_flag = True + def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]: + """Helper to generate the status dict required by the SessionRouter.""" + return {info.actor: info.is_active for info in self.workers_info.values()} + def _get_worker_cls(self): if os.environ.get("XTUNER_USE_LMDEPLOY") == "1": from .lmdeploy import LMDeployWorker @@ -161,7 +178,7 @@ def _is_port_in_use(self, host: str, port: int) -> bool: except OSError: return True - def _update_active_workers_and_urls(self): + def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map): """Update the list of active rollout workers and their server URLs. When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with @@ -170,13 +187,12 @@ def _update_active_workers_and_urls(self): workers and their corresponding URLs. """ if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node: - return + return active_rollout_workers, worker_server_urls_map else: active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node - self.active_rollout_workers = self.active_rollout_workers[::active_worker_interval] - active_rank = list(self.worker_server_urls_map.keys())[::active_worker_interval] - active_worker_server_urls = list(self.worker_server_urls_map.values())[::active_worker_interval] - self.worker_server_urls_map = dict(zip(active_rank, active_worker_server_urls)) + active_rank = list(worker_server_urls_map.keys())[::active_worker_interval] + active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval] + return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls)) def get_rollout_info(self): """Get information about the current rollout setup. @@ -185,10 +201,12 @@ def get_rollout_info(self): dict: A dictionary containing the engine mesh list, server URL dictionary, and the rollout configuration. 
""" + worker_server_urls_status = {url: info.is_active for url, info in self.workers_info.items()} return dict( engine_mesh_list=self.engine_mesh_list, - server_url_dict=self.server_url_dict, + server_url_dict=self.worker_server_urls_map, rollout_config=self.config, + worker_server_urls_status=worker_server_urls_status, ) def init_workers(self): @@ -210,13 +228,13 @@ def init_workers(self): """ active_servers_count, nodes_per_engine = self._get_active_servers_count(self.config, len(self.workers)) interval = len(self.workers) // active_servers_count - self.active_rollout_workers = self.workers[::interval] - self.num_workers = len(self.active_rollout_workers) + active_rollout_workers = self.workers[::interval] + self.num_workers = len(active_rollout_workers) set_bundle_idxs_objectref = [] engine_mesh_list = [] activate_worker_idx = 0 - for active_worker in self.active_rollout_workers: + for active_worker in active_rollout_workers: head_rank, _ = self.rank_bundle_idx_list[activate_worker_idx] engine_workers_meta = self.rank_bundle_idx_list[head_rank : head_rank + interval] engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx) @@ -225,24 +243,35 @@ def init_workers(self): activate_worker_idx += interval ray.get(set_bundle_idxs_objectref) # init dist_init_addr for each worker according to parallel settings - init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in self.active_rollout_workers]) # type: ignore[attr-defined] + init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined] dist_init_addrs = self._update_dist_init_addr(nodes_per_engine, init_dist_init_addrs, self.num_gpus_per_engine) # launch rollout servers - self.worker_server_urls_map = dict( - ray.get( - [ - worker.init.remote(dist_init_addrs[i]) # type: ignore[attr-defined] - for i, worker in enumerate(self.active_rollout_workers) - ] - ) + worker_server_urls_map = dict( # rank -> url + ray.get([worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)]) ) - self._update_active_workers_and_urls() - self.worker_server_urls = list(self.worker_server_urls_map.values()) - self.logger.info(f"Rollout worker server URLs: {self.worker_server_urls}") - self.active_workers_to_status = {worker: True for worker in self.active_rollout_workers} - self.active_url_to_workers = dict(zip(self.worker_server_urls, self.active_rollout_workers)) - self.url_failed_counts = {url: 0 for url in self.worker_server_urls} - return engine_mesh_list, self.worker_server_urls_map + active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map( + active_rollout_workers, worker_server_urls_map + ) + self.workers_info = { + url: WorkerInfo(actor=worker) + for url, worker in zip(worker_server_urls_map.values(), active_rollout_workers) + } + self.logger.info(f"Rollout worker server URLs: {list(self.workers_info.keys())}") + return engine_mesh_list, worker_server_urls_map + + def _deactivate_worker(self, url: str): + """A helper function to deactivate a worker, update all related states, + and shut it down.""" + worker_info = self.workers_info.get(url) + if not worker_info or not worker_info.is_active: + return + + self.logger.warning(f"Deactivating rollout worker {worker_info.actor} with URL {url} due to failures.") + worker_info.is_active = False + self.router.update_active_workers(self._get_worker_status_for_router()) + + ray.get(worker_info.actor.offload.remote()) # type: 
ignore[attr-defined] + ray.get(worker_info.actor.shutdown.remote()) # type: ignore[attr-defined] def check_active_workers(self): """Check the health of all active rollout workers. @@ -251,28 +280,41 @@ def check_active_workers(self): List[bool]: A list of booleans indicating the health status of each active rollout worker. """ + active_workers = [(url, info) for url, info in self.workers_info.items() if info.is_active] + if not active_workers: + return - active_worker_response = ray.get( - [worker.check_health.remote() for worker in self.active_rollout_workers] # type: ignore[attr-defined] - ) - for idx, status in enumerate(active_worker_response): - if not status: - self.logger.info( - f"Rollout worker {self.worker_server_urls[idx]} is unhealthy. Removing it from active workers." - ) - self.active_workers_to_status[self.active_rollout_workers[idx]] = False - - def deactivate_worker_by_url(self, url): - self.url_failed_counts[url] += 1 - if self.url_failed_counts[url] < self.config.max_retry_per_worker: + urls, infos = zip(*active_workers) + actors = [info.actor for info in infos] + + health_statuses = ray.get([actor.check_health.remote() for actor in actors]) + + count = 0 + for url, is_healthy in zip(urls, health_statuses): + if count == 3: + is_healthy = False + count += 1 + if not is_healthy: + self._deactivate_worker(url) + + def deactivate_worker_by_url(self, url: str): + """Deactivates a worker identified by its URL after it exceeds the + maximum retry count.""" + worker_info = self.workers_info.get(url) + if not worker_info or not worker_info.is_active: + return + + worker_info.failure_count += 1 + if ( + self.config.max_retry_per_worker is not None + and worker_info.failure_count < self.config.max_retry_per_worker + ): self.logger.warning( - f"Rollout worker {url} failed {self.url_failed_counts[url]} times, but not deactivated yet." + f"Rollout worker {url} failed {worker_info.failure_count} times, but not deactivated yet." ) return - inactive_workers = self.active_url_to_workers.get(url) - if inactive_workers: - self.logger.warning(f"Deactivating rollout worker {url} due to repeated failures.") - self.active_workers_to_status[inactive_workers] = False + + self._deactivate_worker(url) async def rollout( self, @@ -429,9 +471,11 @@ def _broadcast_to_active_workers(self, method_name: str, block: bool): A list of futures if `block` is False, otherwise a list of results. 
""" futures = [] - for worker, status in self.active_workers_to_status.items(): - if status: - futures.append(getattr(worker, method_name).remote()) + for info in self.workers_info.values(): + if info.is_active: + futures.append(getattr(info.actor, method_name).remote()) + else: + self.logger.warning(f"Skipping {method_name} for inactive worker {info.actor}.") if not block: return futures diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index f22e9ecf2..32443dffa 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -527,7 +527,11 @@ def onload_optimizer(self): self._engine.put_optimizer_to_device(DEVICE) def update_rollout_info( - self, engine_mesh_list: DeviceMeshRaw, server_url_dict: ServiceUrlMap, rollout_config: RolloutConfig + self, + engine_mesh_list: DeviceMeshRaw, + server_url_dict: ServiceUrlMap, + rollout_config: RolloutConfig, + worker_server_urls_status: Dict[str, bool], ): """Update the rollout information for the training worker.""" tp = rollout_config.tensor_parallel_size @@ -537,6 +541,9 @@ def update_rollout_info( "cpu", mesh=engine_mesh_list, mesh_dim_names=("engine_instance", "engine_parallel") ) self.rollout_url = server_url_dict.get(self.rank, "") + if worker_server_urls_status.get(self.rollout_url, False) is False: + self.logger.error(f"Rollout server url {self.rollout_url} is not available.") + self.rollout_url = None self.rollout_cfg_info["tp"] = tp self.rollout_cfg_info["ep"] = ep self.rollout_cfg_info["api_key"] = rollout_config.api_key @@ -858,7 +865,9 @@ def request_update_params(self, state_dict, finished=False): cpu_mesh = self.rollout_device_mesh["engine_parallel"] cpu_group = cpu_mesh.get_group() head_rank = cpu_mesh.mesh[0].item() - + if self.rollout_url is None: + self.logger.error(f"rank {self.rank} url in None, cannot update weights and skip") + return if self.rollout_cfg_info["backend"] == "pytorch": # TODO(chenchiyu): remove lmdeploy related code from lmdeploy.utils import serialize_state_dict @@ -989,7 +998,6 @@ def request_update_params(self, state_dict, finished=False): if use_flattened_tensor_bucket: data["load_format"] = "flattened_bucket" - response = requests.post( f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data ) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 3f7525153..cc072ac8a 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -426,6 +426,9 @@ def fit(self): with timer("saving and sync_weight", step_timer_dict): ray.get(self._train_controller.offload.remote(target="optimizer")) self._maybe_save_hf() + bind_train_rollout( + train_controller=self._train_controller, env_controller=self._rollout_env_controller + ) ray.get(self._rollout_env_controller.onload_weights.remote()) ray.get(self._train_controller.update_weights.remote()) self.logger.info("Model weights synchronized successfully.") From e00a2674169b212d3fe76312cffa672f36c39d18 Mon Sep 17 00:00:00 2001 From: YanhuiDua Date: Tue, 18 Nov 2025 22:40:37 +0800 Subject: [PATCH 17/18] rm useless code --- xtuner/v1/ray/rollout/controller.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/xtuner/v1/ray/rollout/controller.py b/xtuner/v1/ray/rollout/controller.py index 869a672e0..cb0686821 100644 --- a/xtuner/v1/ray/rollout/controller.py +++ b/xtuner/v1/ray/rollout/controller.py @@ -289,11 +289,7 @@ def check_active_workers(self): health_statuses = ray.get([actor.check_health.remote() for actor in actors]) - count = 0 for url, is_healthy in 
zip(urls, health_statuses):
-            if count == 3:
-                is_healthy = False
-            count += 1
             if not is_healthy:
                 self._deactivate_worker(url)

From c240f087b1d7209188cbc092f0b568b8757111eb Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Wed, 19 Nov 2025 12:05:22 +0800
Subject: [PATCH 18/18] fix ci

---
 tests/ray/test_mock_rollout.py   | 73 ++++----------------------------
 xtuner/v1/rl/base/worker.py      |  8 ++--
 xtuner/v1/utils/rl_test_utils.py | 68 +++++++++++++++++++++++++++++
 3 files changed, 81 insertions(+), 68 deletions(-)

diff --git a/tests/ray/test_mock_rollout.py b/tests/ray/test_mock_rollout.py
index bc28f7630..44428487c 100644
--- a/tests/ray/test_mock_rollout.py
+++ b/tests/ray/test_mock_rollout.py
@@ -11,90 +11,33 @@
 from xtuner.v1.datasets import RLTokenizeFnConfig
 from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
 from xtuner.v1.ray.rollout.controller import RolloutController
-from xtuner.v1.utils.httpx_utils import HttpRequestResult, HttpRequestErrorType
+from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker
 MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
 TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
 resource_map = {"npu": "NPU", "cuda": "GPU"}
-
-try:
-    from xtuner.v1.ray.rollout import LMDeployWorker
-except ImportError:
-    LMDeployWorker = object  # or Any
-
-class MockTimeoutRolloutWorker(LMDeployWorker):
-    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
-        try:
-            raise httpx.TimeoutException("Mocked timeout error")
-        except Exception as e:
-            error_type = HttpRequestErrorType.from_exception(e)
-            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
-            self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}")
-            return result
-
-    def launch_server(self):
-        pass  # Override
-
-class MockRequestErrorRolloutWorker(LMDeployWorker):
-    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
-        try:
-            req = httpx.Request("POST", url)
-            raise httpx.RequestError("Mocked httpx request error", request=req)
-        except Exception as e:
-            error_type = HttpRequestErrorType.from_exception(e)
-            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
-            self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}")
-            return result
-
-    def launch_server(self):
-        pass  # Override
-
-class MockClientErrorRolloutWorker(LMDeployWorker):
-    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
-        try:
-            req = httpx.Request("POST", url)
-            res = httpx.Response(400, request=req)
-            raise httpx.HTTPStatusError("Mocked client error", request=req, response=res)
-        except Exception as e:
-            error_type = HttpRequestErrorType.from_exception(e)
-            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
-            self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}")
-            return result
-
-    def launch_server(self):
-        pass  # Override
-
-class MockServerErrorRolloutWorker(LMDeployWorker):
-    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
-        try:
-            req = httpx.Request("POST", url)
-            res = httpx.Response(500, request=req)
-            raise httpx.HTTPStatusError("Mocked server error", request=req, response=res)
-        except Exception as e:
-            error_type = HttpRequestErrorType.from_exception(e)
-            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
-            self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}")
-            return result
-
-    def launch_server(self):
-        pass  # Override
-
 @ray.remote
 class MockTimeoutRolloutController(RolloutController):
     def _get_worker_cls(self):
         return ray.remote(MockTimeoutRolloutWorker)
 
 @ray.remote
 class MockServerErrorRolloutController(RolloutController):
     def _get_worker_cls(self):
         return ray.remote(MockServerErrorRolloutWorker)
+
+    def deactivate_worker_by_url(self, url):
+        pass
 
 class TestMockRollout(unittest.TestCase):
     @classmethod
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 32443dffa..d08a37b63 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -540,10 +540,12 @@ def update_rollout_info(
         self.rollout_device_mesh = DeviceMesh(
             "cpu", mesh=engine_mesh_list, mesh_dim_names=("engine_instance", "engine_parallel")
         )
-        self.rollout_url = server_url_dict.get(self.rank, "")
-        if worker_server_urls_status.get(self.rollout_url, False) is False:
-            self.logger.error(f"Rollout server url {self.rollout_url} is not available.")
+        rollout_server_url = server_url_dict.get(self.rank, "")
+        if worker_server_urls_status.get(rollout_server_url, False) is False:
+            self.logger.error(f"Rollout server url {rollout_server_url} is not available.")
             self.rollout_url = None
+        else:
+            self.rollout_url = rollout_server_url
         self.rollout_cfg_info["tp"] = tp
         self.rollout_cfg_info["ep"] = ep
         self.rollout_cfg_info["api_key"] = rollout_config.api_key
diff --git a/xtuner/v1/utils/rl_test_utils.py b/xtuner/v1/utils/rl_test_utils.py
index 6d89d64a9..c02eb604e 100644
--- a/xtuner/v1/utils/rl_test_utils.py
+++ b/xtuner/v1/utils/rl_test_utils.py
@@ -4,6 +4,7 @@
 import time
 from typing import Any, Dict, List
 
+import httpx
 import requests
 import uvicorn
 from fastapi import FastAPI
@@ -11,6 +12,73 @@
 from xtuner.v1.ray.judger.native import NativeJudger
 
+# try:
+from xtuner.v1.ray.rollout.lmdeploy import LMDeployWorker
+from xtuner.v1.utils.httpx_utils import HttpRequestErrorType, HttpRequestResult
+
+
+# except ImportError:
+#     LMDeployWorker = object
+class MockTimeoutRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            raise httpx.TimeoutException("Mocked timeout error")
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked timeout exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
+class MockRequestErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            raise httpx.RequestError("Mocked httpx request error", request=req)
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked request error exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
+class MockClientErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            res = httpx.Response(400, request=req)
+            raise httpx.HTTPStatusError("Mocked client error", request=req, response=res)
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked client exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
+
+class MockServerErrorRolloutWorker(LMDeployWorker):
+    async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
+        try:
+            req = httpx.Request("POST", url)
+            res = httpx.Response(500, request=req)
+            raise httpx.HTTPStatusError("Mocked server error", request=req, response=res)
+        except Exception as e:
+            error_type = HttpRequestErrorType.from_exception(e)
+            result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)
+            self.logger.info(f"Caught mocked server exception: {e.__class__.__name__}")
+            return result
+
+    def launch_server(self):
+        pass  # Override
+
 app = FastAPI()
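Read end to end, the series converges on one classification path: a worker's `_safe_post_request`
catches the exception, `HttpRequestErrorType.from_exception` turns it into an error code, and
`set_rollout_response_status` maps that code onto a `finish_reason` ("failed" for retryable and
server errors, "skipped" for client errors). A hedged sketch of that path, assuming
`from_exception` maps an `httpx.HTTPStatusError` onto its response status code, as the mock
client/server tests above rely on; the URL below is a placeholder:

    import httpx
    from xtuner.v1.data_proto.rl_data import RLRolloutResponseItem
    from xtuner.v1.utils.httpx_utils import (
        HttpRequestErrorType,
        HttpRequestResult,
        set_rollout_response_status,
    )

    def classify(exc: Exception, url: str) -> RLRolloutResponseItem:
        # Build the result the same way the (mock) workers do, then map it onto
        # a rollout response via the shared helper.
        result = HttpRequestResult(
            error_type=HttpRequestErrorType.from_exception(exc),
            exception=exc,
            url=url,
            payload=None,
        )
        response = RLRolloutResponseItem(finish_reason="failed")
        set_rollout_response_status(result, response, server_url=url)
        return response

    url = "http://localhost:8000/v1/chat/completions"  # placeholder
    req = httpx.Request("POST", url)
    client_exc = httpx.HTTPStatusError(
        "client error", request=req, response=httpx.Response(400, request=req)
    )
    assert classify(client_exc, url).finish_reason == "skipped"  # 4xx -> skip the request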