Skip to content

Commit 03fef2c

Browse files
committed
fix deactivate rollout worker
1 parent 2f87ebf commit 03fef2c

File tree

3 files changed

+110
-55
lines changed

3 files changed

+110
-55
lines changed

xtuner/v1/ray/rollout/controller.py

Lines changed: 96 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import threading
55
import time
66
from collections import OrderedDict
7+
from dataclasses import dataclass
78
from itertools import cycle
89
from typing import Any, Dict, List, Optional, Union
910
from uuid import uuid4
@@ -22,6 +23,15 @@
2223
from .worker import RolloutWorker
2324

2425

26+
@dataclass
27+
class WorkerInfo:
28+
"""A data class to hold all state information for a single worker."""
29+
30+
actor: RolloutWorker
31+
is_active: bool = True
32+
failure_count: int = 0
33+
34+
2535
class SessionRouter:
2636
def __init__(
2737
self,
@@ -38,6 +48,7 @@ def __init__(
3848
self._map: OrderedDict[int, tuple[Any, float]] = OrderedDict()
3949
self._worker_cycler = cycle(self._workers)
4050
self._lock = asyncio.Lock()
51+
self.logger = get_logger()
4152

4253
def _now(self) -> float:
4354
return time.time()
@@ -60,6 +71,11 @@ def _evict_lru_to_capacity(self):
6071
while len(self._map) > self._max_sessions:
6172
self._map.popitem(last=False)
6273

74+
def update_active_workers(self, worker_status: Dict[Any, bool]):
75+
self._workers = list(worker_status.items())
76+
self.logger.debug(f"SessionRouter update active workers: {self._workers}")
77+
self._worker_cycler = cycle(self._workers)
78+
6379
async def get_worker(self, session_id: int) -> Any:
6480
async with self._lock:
6581
self._evict_expired()
@@ -103,21 +119,18 @@ def __init__(
103119
)
104120
self.logger = get_logger(log_dir=infer_config.worker_log_dir, tag="RolloutController")
105121
self.num_workers = 0
122+
self.workers_info: Dict[str, WorkerInfo] = {} # url -> WorkerInfo
106123
self.worker_server_urls: List[str] = []
107124
self.active_rollout_workers: List[RolloutWorker] = []
108-
self.active_workers_to_status: Dict[RolloutWorker, bool] = {}
109-
self.active_url_to_workers: Dict[str, RolloutWorker] = {}
110-
self.url_failed_counts: Dict[str, int] = {}
111125
self.tokenizer = AutoTokenizer.from_pretrained(infer_config.tokenizer_path, trust_remote_code=True)
112126
self.workers, self.rank_bundle_idx_list = AutoAcceleratorWorkers.from_placement_group(
113127
self._get_worker_cls(), infer_config, placement_group
114128
)
115-
self.engine_mesh_list, self.server_url_dict = self.init_workers()
129+
self.engine_mesh_list, self.worker_server_urls_map = self.init_workers()
116130
self.start_api_server()
117131
# todo(@duanyanhui): add router to replace native round robin
118-
self.router = SessionRouter(self.active_workers_to_status)
132+
self.router = SessionRouter(self._get_worker_status_for_router())
119133
self.sample_params = SampleParams().dict()
120-
# note: 目前默认使用return_token_ids和return_logprob,并且不使用流式
121134
self.extra_params = dict(
122135
RolloutExtraParams(
123136
stream=False,
@@ -132,6 +145,10 @@ def __init__(
132145
)
133146
self.print_params_flag = True
134147

148+
def _get_worker_status_for_router(self) -> Dict[RolloutWorker, bool]:
149+
"""Helper to generate the status dict required by the SessionRouter."""
150+
return {info.actor: info.is_active for info in self.workers_info.values()}
151+
135152
def _get_worker_cls(self):
136153
if os.environ.get("XTUNER_USE_LMDEPLOY") == "1":
137154
from .lmdeploy import LMDeployWorker
@@ -161,7 +178,7 @@ def _is_port_in_use(self, host: str, port: int) -> bool:
161178
except OSError:
162179
return True
163180

164-
def _update_active_workers_and_urls(self):
181+
def _update_active_workers_and_urls_map(self, active_rollout_workers, worker_server_urls_map):
165182
"""Update the list of active rollout workers and their server URLs.
166183
167184
When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with
@@ -170,13 +187,12 @@ def _update_active_workers_and_urls(self):
170187
workers and their corresponding URLs.
171188
"""
172189
if self.config.rollout_cross_node_comm or self.num_gpus_per_engine < self.config.gpus_per_node:
173-
return
190+
return active_rollout_workers, worker_server_urls_map
174191
else:
175192
active_worker_interval = self.num_gpus_per_engine // self.config.gpus_per_node
176-
self.active_rollout_workers = self.active_rollout_workers[::active_worker_interval]
177-
active_rank = list(self.worker_server_urls_map.keys())[::active_worker_interval]
178-
active_worker_server_urls = list(self.worker_server_urls_map.values())[::active_worker_interval]
179-
self.worker_server_urls_map = dict(zip(active_rank, active_worker_server_urls))
193+
active_rank = list(worker_server_urls_map.keys())[::active_worker_interval]
194+
active_worker_server_urls = list(worker_server_urls_map.values())[::active_worker_interval]
195+
return active_rollout_workers[::active_worker_interval], dict(zip(active_rank, active_worker_server_urls))
180196

181197
def get_rollout_info(self):
182198
"""Get information about the current rollout setup.
@@ -185,10 +201,12 @@ def get_rollout_info(self):
185201
dict: A dictionary containing the engine mesh list, server URL
186202
dictionary, and the rollout configuration.
187203
"""
204+
worker_server_urls_status = {url: info.is_active for url, info in self.workers_info.items()}
188205
return dict(
189206
engine_mesh_list=self.engine_mesh_list,
190-
server_url_dict=self.server_url_dict,
207+
server_url_dict=self.worker_server_urls_map,
191208
rollout_config=self.config,
209+
worker_server_urls_status=worker_server_urls_status,
192210
)
193211

194212
def init_workers(self):
@@ -210,13 +228,13 @@ def init_workers(self):
210228
"""
211229
active_servers_count, nodes_per_engine = self._get_active_servers_count(self.config, len(self.workers))
212230
interval = len(self.workers) // active_servers_count
213-
self.active_rollout_workers = self.workers[::interval]
214-
self.num_workers = len(self.active_rollout_workers)
231+
active_rollout_workers = self.workers[::interval]
232+
self.num_workers = len(active_rollout_workers)
215233

216234
set_bundle_idxs_objectref = []
217235
engine_mesh_list = []
218236
activate_worker_idx = 0
219-
for active_worker in self.active_rollout_workers:
237+
for active_worker in active_rollout_workers:
220238
head_rank, _ = self.rank_bundle_idx_list[activate_worker_idx]
221239
engine_workers_meta = self.rank_bundle_idx_list[head_rank : head_rank + interval]
222240
engine_bundle_idxs = [meta[1] for meta in engine_workers_meta] # meta: (rank, bundle_idx)
@@ -225,24 +243,35 @@ def init_workers(self):
225243
activate_worker_idx += interval
226244
ray.get(set_bundle_idxs_objectref)
227245
# init dist_init_addr for each worker according to parallel settings
228-
init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in self.active_rollout_workers]) # type: ignore[attr-defined]
246+
init_dist_init_addrs = ray.get([worker.init_dist_port.remote() for worker in active_rollout_workers]) # type: ignore[attr-defined]
229247
dist_init_addrs = self._update_dist_init_addr(nodes_per_engine, init_dist_init_addrs, self.num_gpus_per_engine)
230248
# launch rollout servers
231-
self.worker_server_urls_map = dict(
232-
ray.get(
233-
[
234-
worker.init.remote(dist_init_addrs[i]) # type: ignore[attr-defined]
235-
for i, worker in enumerate(self.active_rollout_workers)
236-
]
237-
)
249+
worker_server_urls_map = dict( # rank -> url
250+
ray.get([worker.init.remote(dist_init_addrs[i]) for i, worker in enumerate(active_rollout_workers)])
238251
)
239-
self._update_active_workers_and_urls()
240-
self.worker_server_urls = list(self.worker_server_urls_map.values())
241-
self.logger.info(f"Rollout worker server URLs: {self.worker_server_urls}")
242-
self.active_workers_to_status = {worker: True for worker in self.active_rollout_workers}
243-
self.active_url_to_workers = dict(zip(self.worker_server_urls, self.active_rollout_workers))
244-
self.url_failed_counts = {url: 0 for url in self.worker_server_urls}
245-
return engine_mesh_list, self.worker_server_urls_map
252+
active_rollout_workers, worker_server_urls_map = self._update_active_workers_and_urls_map(
253+
active_rollout_workers, worker_server_urls_map
254+
)
255+
self.workers_info = {
256+
url: WorkerInfo(actor=worker)
257+
for url, worker in zip(worker_server_urls_map.values(), active_rollout_workers)
258+
}
259+
self.logger.info(f"Rollout worker server URLs: {list(self.workers_info.keys())}")
260+
return engine_mesh_list, worker_server_urls_map
261+
262+
def _deactivate_worker(self, url: str):
263+
"""A helper function to deactivate a worker, update all related states,
264+
and shut it down."""
265+
worker_info = self.workers_info.get(url)
266+
if not worker_info or not worker_info.is_active:
267+
return
268+
269+
self.logger.warning(f"Deactivating rollout worker {worker_info.actor} with URL {url} due to failures.")
270+
worker_info.is_active = False
271+
self.router.update_active_workers(self._get_worker_status_for_router())
272+
273+
ray.get(worker_info.actor.offload.remote()) # type: ignore[attr-defined]
274+
ray.get(worker_info.actor.shutdown.remote()) # type: ignore[attr-defined]
246275

247276
def check_active_workers(self):
248277
"""Check the health of all active rollout workers.
@@ -251,28 +280,41 @@ def check_active_workers(self):
251280
List[bool]: A list of booleans indicating the health status of
252281
each active rollout worker.
253282
"""
283+
active_workers = [(url, info) for url, info in self.workers_info.items() if info.is_active]
284+
if not active_workers:
285+
return
254286

255-
active_worker_response = ray.get(
256-
[worker.check_health.remote() for worker in self.active_rollout_workers] # type: ignore[attr-defined]
257-
)
258-
for idx, status in enumerate(active_worker_response):
259-
if not status:
260-
self.logger.info(
261-
f"Rollout worker {self.worker_server_urls[idx]} is unhealthy. Removing it from active workers."
262-
)
263-
self.active_workers_to_status[self.active_rollout_workers[idx]] = False
264-
265-
def deactivate_worker_by_url(self, url):
266-
self.url_failed_counts[url] += 1
267-
if self.url_failed_counts[url] < self.config.max_retry_per_worker:
287+
urls, infos = zip(*active_workers)
288+
actors = [info.actor for info in infos]
289+
290+
health_statuses = ray.get([actor.check_health.remote() for actor in actors])
291+
292+
for url, is_healthy in zip(urls, health_statuses):
293+
if not is_healthy:
294+
self._deactivate_worker(url)
299+
300+
def deactivate_worker_by_url(self, url: str):
301+
"""Deactivates a worker identified by its URL after it exceeds the
302+
maximum retry count."""
303+
worker_info = self.workers_info.get(url)
304+
if not worker_info or not worker_info.is_active:
305+
return
306+
307+
worker_info.failure_count += 1
308+
if (
309+
self.config.max_retry_per_worker is not None
310+
and worker_info.failure_count < self.config.max_retry_per_worker
311+
):
268312
self.logger.warning(
269-
f"Rollout worker {url} failed {self.url_failed_counts[url]} times, but not deactivated yet."
313+
f"Rollout worker {url} failed {worker_info.failure_count} times, but not deactivated yet."
270314
)
271315
return
272-
inactive_workers = self.active_url_to_workers.get(url)
273-
if inactive_workers:
274-
self.logger.warning(f"Deactivating rollout worker {url} due to repeated failures.")
275-
self.active_workers_to_status[inactive_workers] = False
316+
317+
self._deactivate_worker(url)
276318

277319
async def rollout(
278320
self,
@@ -429,9 +471,11 @@ def _broadcast_to_active_workers(self, method_name: str, block: bool):
429471
A list of futures if `block` is False, otherwise a list of results.
430472
"""
431473
futures = []
432-
for worker, status in self.active_workers_to_status.items():
433-
if status:
434-
futures.append(getattr(worker, method_name).remote())
474+
for info in self.workers_info.values():
475+
if info.is_active:
476+
futures.append(getattr(info.actor, method_name).remote())
477+
else:
478+
self.logger.warning(f"Skipping {method_name} for inactive worker {info.actor}.")
435479

436480
if not block:
437481
return futures

xtuner/v1/rl/base/worker.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,11 @@ def onload_optimizer(self):
527527
self._engine.put_optimizer_to_device(DEVICE)
528528

529529
def update_rollout_info(
530-
self, engine_mesh_list: DeviceMeshRaw, server_url_dict: ServiceUrlMap, rollout_config: RolloutConfig
530+
self,
531+
engine_mesh_list: DeviceMeshRaw,
532+
server_url_dict: ServiceUrlMap,
533+
rollout_config: RolloutConfig,
534+
worker_server_urls_status: Dict[str, bool],
531535
):
532536
"""Update the rollout information for the training worker."""
533537
tp = rollout_config.tensor_parallel_size
@@ -537,6 +541,9 @@ def update_rollout_info(
537541
"cpu", mesh=engine_mesh_list, mesh_dim_names=("engine_instance", "engine_parallel")
538542
)
539543
self.rollout_url = server_url_dict.get(self.rank, "")
544+
if worker_server_urls_status.get(self.rollout_url, False) is False:
545+
self.logger.error(f"Rollout server url {self.rollout_url} is not available.")
546+
self.rollout_url = None
540547
self.rollout_cfg_info["tp"] = tp
541548
self.rollout_cfg_info["ep"] = ep
542549
self.rollout_cfg_info["api_key"] = rollout_config.api_key
@@ -858,7 +865,9 @@ def request_update_params(self, state_dict, finished=False):
858865
cpu_mesh = self.rollout_device_mesh["engine_parallel"]
859866
cpu_group = cpu_mesh.get_group()
860867
head_rank = cpu_mesh.mesh[0].item()
861-
868+
if self.rollout_url is None:
869+
self.logger.error(f"rank {self.rank} url is None, cannot update weights and skip")
870+
return
862871
if self.rollout_cfg_info["backend"] == "pytorch":
863872
# TODO(chenchiyu): remove lmdeploy related code
864873
from lmdeploy.utils import serialize_state_dict
@@ -989,7 +998,6 @@ def request_update_params(self, state_dict, finished=False):
989998

990999
if use_flattened_tensor_bucket:
9911000
data["load_format"] = "flattened_bucket"
992-
9931001
response = requests.post(
9941002
f"{self.rollout_url}/{self.endpoints['update_weights']}", headers=headers, json=data
9951003
)

xtuner/v1/train/rl_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ def fit(self):
426426
with timer("saving and sync_weight", step_timer_dict):
427427
ray.get(self._train_controller.offload.remote(target="optimizer"))
428428
self._maybe_save_hf()
429+
bind_train_rollout(
430+
train_controller=self._train_controller, env_controller=self._rollout_env_controller
431+
)
429432
ray.get(self._rollout_env_controller.onload_weights.remote())
430433
ray.get(self._train_controller.update_weights.remote())
431434
self.logger.info("Model weights synchronized successfully.")

0 commit comments

Comments
 (0)