Skip to content

Commit d422029

Browse files
committed
fix comments
1 parent bc04815 commit d422029

File tree

6 files changed

+16
-9
lines changed

6 files changed

+16
-9
lines changed

xtuner/v1/ray/base/accelerator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,10 @@ def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup):
404404
rank_bundle_idx_list = []
405405
for rank, bundle_idx in enumerate(sorted_bundle_idxs):
406406
worker = worker_cls.options(
407-
placement_group=pg, placement_group_bundle_index=bundle_idx, **pg_options
407+
max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)),
408+
placement_group=pg,
409+
placement_group_bundle_index=bundle_idx,
410+
**pg_options,
408411
).remote(worker_config, rank, master_addr, master_port, world_size, device_type)
409412
workers_list.append(worker)
410413
rank_bundle_idx_list.append((rank, bundle_idx))

xtuner/v1/ray/dataflow/flow.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,8 @@ async def concurrent_task_runner(self):
236236
next_update_threshold = update_step
237237
while (
238238
self.finished_samples_count < self.target_batch_size
239-
and self.failed_samples_count < self.target_batch_size
240-
and self.skipped_sample_count < self.target_batch_size
239+
and self.failed_samples_count < self.target_batch_size * self.config.max_retry_times
240+
and self.skipped_sample_count < self.target_batch_size * self.config.max_retry_times
241241
):
242242
if self.finished_samples_count >= next_update_threshold:
243243
pbar.n = self.finished_samples_count
@@ -279,8 +279,10 @@ async def concurrent_task_runner(self):
279279

280280
if self.finished_samples_count >= self.target_batch_size:
281281
self.logger.info("Target batch size reached. Pausing env controller.")
282-
if self.failed_samples_count >= self.target_batch_size or self.skipped_sample_count >= self.target_batch_size:
282+
if self.failed_samples_count >= self.target_batch_size * self.config.max_retry_times:
283283
self.logger.info("Max failed samples reached. Pausing env controller.")
284+
if self.skipped_sample_count >= self.target_batch_size * self.config.max_retry_times:
285+
self.logger.info("Max skipped samples reached. Pausing env controller.")
284286

285287
# NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()`
286288
# would be queued behind many worker tasks, causing a significant delay.

xtuner/v1/ray/environment/base_env.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os
12
from abc import ABC, abstractmethod
23
from typing import Any, List
34

@@ -70,7 +71,11 @@ def init_rollout_controller(self, rollout_cfg: Any, placement_group: Any):
7071

7172
from xtuner.v1.ray.rollout.controller import RolloutController
7273

73-
rollout_controller = ray.remote(RolloutController).remote(rollout_cfg, placement_group) # type: ignore[attr-defined]
74+
rollout_controller = (
75+
ray.remote(RolloutController)
76+
.options(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)))
77+
.remote(rollout_cfg, placement_group)
78+
) # type: ignore[attr-defined]
7479
return rollout_controller
7580

7681
def init_judger_controller(self, judger_cfg: Any, placement_group: Any):

xtuner/v1/ray/rollout/controller.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ async def rollout(
331331
url = response.extra_info["url"]
332332
if response.finish_reason == "failed":
333333
self.deactivate_worker_by_url(url)
334+
response.extra_info.pop("url", None)
334335
return response
335336
except asyncio.TimeoutError:
336337
self.logger.error("Get response from rollout worker timeout and return the failed response.")

xtuner/v1/ray/rollout/sglang.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
from typing import Any, Dict, List, Union
33

4-
import ray
54
import requests
65
from urllib3.exceptions import NewConnectionError
76

@@ -11,7 +10,6 @@
1110
from .worker import RolloutWorker
1211

1312

14-
@ray.remote(max_concurrency=int(os.environ.get("RAY_MAX_CONCURRENCY", 1000)))
1513
class SGLangWorker(RolloutWorker):
1614
def __init__(
1715
self,

xtuner/v1/ray/rollout/vllm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from argparse import Namespace
22
from typing import Any, Dict, List, Union
33

4-
import ray
54
import uvloop
65
from vllm.entrypoints.openai.api_server import run_server
76
from vllm.entrypoints.openai.cli_args import make_arg_parser
@@ -16,7 +15,6 @@ def run_vllm_server_wrapper(server_args):
1615
uvloop.run(run_server(server_args))
1716

1817

19-
@ray.remote
2018
class vLLMWorker(RolloutWorker):
2119
def __init__(
2220
self,

0 commit comments

Comments
 (0)