Skip to content

Commit fcb40aa

Browse files
committed
fix comments
1 parent 8a0d8bf commit fcb40aa

File tree

5 files changed

+11
-11
lines changed

5 files changed

+11
-11
lines changed

tests/ray/test_mock_rollout.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@
44
from transformers import AutoTokenizer
55
import torch
66
from xtuner.v1.ray.config.worker import RolloutConfig
7-
from xtuner.v1.ray.judger.controller import JudgerConfig
87
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
98
from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
10-
from xtuner.v1.data_proto.rl_data import SampleParams
119
from xtuner.v1.ray.environment import SingleTurnEnvironment
12-
from xtuner.v1.ray.judger.dapo_math import DapoMathJudgerConfig
13-
from xtuner.v1.datasets import RLTokenizeFnConfig, build_datasets
10+
from xtuner.v1.datasets import RLTokenizeFnConfig
1411
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
1512
from xtuner.v1.ray.rollout.controller import RolloutController
1613
# 导入 Mock Worker

xtuner/v1/data_proto/rl_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def check_dataflow_item(group_data_items):
153153

154154
# 如果存在abort的状态,相当于跳过检查,下次会重新rollout
155155
is_abort = any(item.env.rollout.finish_reason == "abort" for item in group_data_items)
156-
is_skipped = all(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
156+
is_skipped = any(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
157157
if is_abort or is_skipped:
158158
return True
159159

xtuner/v1/ray/rollout/worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ async def rollout_task(
331331
http_result = await self._create_request(
332332
endpoint_url,
333333
openai_prompts,
334-
None,
334+
input_ids,
335335
openai_tools,
336336
tool_choice,
337337
sample_params=sample_params,

xtuner/v1/utils/httpx_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ class HttpRequestErrorType(IntEnum):
1616
an HTTP status code.
1717
1818
Example:
19-
if error_code == RequestErrorType.BAD_REQUEST:
19+
if error_code == HttpRequestErrorType.BAD_REQUEST:
2020
print("Bad request from server!")
21-
elif error_code == RequestErrorType.TIMEOUT_ERROR:
21+
elif error_code == HttpRequestErrorType.TIMEOUT_ERROR:
2222
print("Client-side request timed out!")
2323
"""
2424

@@ -86,7 +86,7 @@ def __post_init__(self):
8686
default_messages = {
8787
HttpRequestErrorType.UNKNOWN_ERROR: f"An unknown error {self.exception} occurred, Traceback: {traceback.format_exc()}",
8888
HttpRequestErrorType.TIMEOUT_ERROR: "The request timed out.",
89-
HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred occurred. TypeError: {type(self.exception)}",
89+
HttpRequestErrorType.REQUEST_ERROR: f"A network request error occurred. TypeError: {type(self.exception)}",
9090
HttpRequestErrorType.BAD_REQUEST: f"Bad Request (400): The server could not process the request {self.payload}",
9191
HttpRequestErrorType.UNAUTHORIZED: "Unauthorized (401): Authentication failed or is required.",
9292
HttpRequestErrorType.FORBIDDEN: "Forbidden (403): Access is denied.",
@@ -139,4 +139,6 @@ def set_rollout_response_status(http_result: HttpRequestResult, response: RLRoll
139139
elif http_result.is_server_error:
140140
response.finish_reason = "failed"
141141
if server_url:
142-
response.extra_info = {"url": server_url}
142+
if response.extra_info is None:
143+
response.extra_info = {}
144+
response.extra_info.update({"url": server_url})

xtuner/v1/utils/rl_test_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def launch_server(self):
5353
class MockRequestErrorRolloutWorker(LMDeployWorker):
5454
async def _safe_post_request(self, url, headers, payload) -> HttpRequestResult:
5555
try:
56-
raise httpx.RequestError("Mocked httpx request error")
56+
req = httpx.Request("POST", url)
57+
raise httpx.RequestError("Mocked httpx request error", request=req)
5758
except Exception as e:
5859
error_type = HttpRequestErrorType.from_exception(e)
5960
result = HttpRequestResult(error_type=error_type, exception=e, url=url, payload=payload)

0 commit comments

Comments
 (0)