44import threading
55import time
66from collections import OrderedDict
7+ from dataclasses import dataclass
78from itertools import cycle
89from typing import Any , Dict , List , Optional , Union
910from uuid import uuid4
2223from .worker import RolloutWorker
2324
2425
26+ @dataclass
27+ class WorkerInfo :
28+ """A data class to hold all state information for a single worker."""
29+
30+ actor : RolloutWorker
31+ is_active : bool = True
32+ failure_count : int = 0
33+
34+
2535class SessionRouter :
2636 def __init__ (
2737 self ,
@@ -38,6 +48,7 @@ def __init__(
3848 self ._map : OrderedDict [int , tuple [Any , float ]] = OrderedDict ()
3949 self ._worker_cycler = cycle (self ._workers )
4050 self ._lock = asyncio .Lock ()
51+ self .logger = get_logger ()
4152
4253 def _now (self ) -> float :
4354 return time .time ()
@@ -60,6 +71,11 @@ def _evict_lru_to_capacity(self):
6071 while len (self ._map ) > self ._max_sessions :
6172 self ._map .popitem (last = False )
6273
74+ def update_active_workers (self , worker_status : Dict [Any , bool ]):
75+ self ._workers = list (worker_status .items ())
76+ self .logger .debug (f"SessionRouter update active workers: { self ._workers } " )
77+ self ._worker_cycler = cycle (self ._workers )
78+
6379 async def get_worker (self , session_id : int ) -> Any :
6480 async with self ._lock :
6581 self ._evict_expired ()
@@ -103,21 +119,18 @@ def __init__(
103119 )
104120 self .logger = get_logger (log_dir = infer_config .worker_log_dir , tag = "RolloutController" )
105121 self .num_workers = 0
122+ self .workers_info : Dict [str , WorkerInfo ] = {} # url -> WorkerInfo
106123 self .worker_server_urls : List [str ] = []
107124 self .active_rollout_workers : List [RolloutWorker ] = []
108- self .active_workers_to_status : Dict [RolloutWorker , bool ] = {}
109- self .active_url_to_workers : Dict [str , RolloutWorker ] = {}
110- self .url_failed_counts : Dict [str , int ] = {}
111125 self .tokenizer = AutoTokenizer .from_pretrained (infer_config .tokenizer_path , trust_remote_code = True )
112126 self .workers , self .rank_bundle_idx_list = AutoAcceleratorWorkers .from_placement_group (
113127 self ._get_worker_cls (), infer_config , placement_group
114128 )
115- self .engine_mesh_list , self .server_url_dict = self .init_workers ()
129+ self .engine_mesh_list , self .worker_server_urls_map = self .init_workers ()
116130 self .start_api_server ()
117131 # todo(@duanyanhui): add router to replace native round robin
118- self .router = SessionRouter (self .active_workers_to_status )
132+ self .router = SessionRouter (self ._get_worker_status_for_router () )
119133 self .sample_params = SampleParams ().dict ()
120- # note: 目前默认使用return_token_ids和return_logprob,并且不使用流式
121134 self .extra_params = dict (
122135 RolloutExtraParams (
123136 stream = False ,
@@ -132,6 +145,10 @@ def __init__(
132145 )
133146 self .print_params_flag = True
134147
148+ def _get_worker_status_for_router (self ) -> Dict [RolloutWorker , bool ]:
149+ """Helper to generate the status dict required by the SessionRouter."""
150+ return {info .actor : info .is_active for info in self .workers_info .values ()}
151+
135152 def _get_worker_cls (self ):
136153 if os .environ .get ("XTUNER_USE_LMDEPLOY" ) == "1" :
137154 from .lmdeploy import LMDeployWorker
@@ -161,7 +178,7 @@ def _is_port_in_use(self, host: str, port: int) -> bool:
161178 except OSError :
162179 return True
163180
164- def _update_active_workers_and_urls (self ):
181+ def _update_active_workers_and_urls_map (self , active_rollout_workers , worker_server_urls_map ):
165182 """Update the list of active rollout workers and their server URLs.
166183
167184 When the inference engine is launched across nodes (rollout_cross_node_comm=True), only the worker with
@@ -170,13 +187,12 @@ def _update_active_workers_and_urls(self):
170187 workers and their corresponding URLs.
171188 """
172189 if self .config .rollout_cross_node_comm or self .num_gpus_per_engine < self .config .gpus_per_node :
173- return
190+ return active_rollout_workers , worker_server_urls_map
174191 else :
175192 active_worker_interval = self .num_gpus_per_engine // self .config .gpus_per_node
176- self .active_rollout_workers = self .active_rollout_workers [::active_worker_interval ]
177- active_rank = list (self .worker_server_urls_map .keys ())[::active_worker_interval ]
178- active_worker_server_urls = list (self .worker_server_urls_map .values ())[::active_worker_interval ]
179- self .worker_server_urls_map = dict (zip (active_rank , active_worker_server_urls ))
193+ active_rank = list (worker_server_urls_map .keys ())[::active_worker_interval ]
194+ active_worker_server_urls = list (worker_server_urls_map .values ())[::active_worker_interval ]
195+ return active_rollout_workers [::active_worker_interval ], dict (zip (active_rank , active_worker_server_urls ))
180196
181197 def get_rollout_info (self ):
182198 """Get information about the current rollout setup.
@@ -185,10 +201,12 @@ def get_rollout_info(self):
185201 dict: A dictionary containing the engine mesh list, server URL
186202 dictionary, and the rollout configuration.
187203 """
204+ worker_server_urls_status = {url : info .is_active for url , info in self .workers_info .items ()}
188205 return dict (
189206 engine_mesh_list = self .engine_mesh_list ,
190- server_url_dict = self .server_url_dict ,
207+ server_url_dict = self .worker_server_urls_map ,
191208 rollout_config = self .config ,
209+ worker_server_urls_status = worker_server_urls_status ,
192210 )
193211
194212 def init_workers (self ):
@@ -210,13 +228,13 @@ def init_workers(self):
210228 """
211229 active_servers_count , nodes_per_engine = self ._get_active_servers_count (self .config , len (self .workers ))
212230 interval = len (self .workers ) // active_servers_count
213- self . active_rollout_workers = self .workers [::interval ]
214- self .num_workers = len (self . active_rollout_workers )
231+ active_rollout_workers = self .workers [::interval ]
232+ self .num_workers = len (active_rollout_workers )
215233
216234 set_bundle_idxs_objectref = []
217235 engine_mesh_list = []
218236 activate_worker_idx = 0
219- for active_worker in self . active_rollout_workers :
237+ for active_worker in active_rollout_workers :
220238 head_rank , _ = self .rank_bundle_idx_list [activate_worker_idx ]
221239 engine_workers_meta = self .rank_bundle_idx_list [head_rank : head_rank + interval ]
222240 engine_bundle_idxs = [meta [1 ] for meta in engine_workers_meta ] # meta: (rank, bundle_idx)
@@ -225,24 +243,35 @@ def init_workers(self):
225243 activate_worker_idx += interval
226244 ray .get (set_bundle_idxs_objectref )
227245 # init dist_init_addr for each worker according to parallel settings
228- init_dist_init_addrs = ray .get ([worker .init_dist_port .remote () for worker in self . active_rollout_workers ]) # type: ignore[attr-defined]
246+ init_dist_init_addrs = ray .get ([worker .init_dist_port .remote () for worker in active_rollout_workers ]) # type: ignore[attr-defined]
229247 dist_init_addrs = self ._update_dist_init_addr (nodes_per_engine , init_dist_init_addrs , self .num_gpus_per_engine )
230248 # launch rollout servers
231- self .worker_server_urls_map = dict (
232- ray .get (
233- [
234- worker .init .remote (dist_init_addrs [i ]) # type: ignore[attr-defined]
235- for i , worker in enumerate (self .active_rollout_workers )
236- ]
237- )
249+ worker_server_urls_map = dict ( # rank -> url
250+ ray .get ([worker .init .remote (dist_init_addrs [i ]) for i , worker in enumerate (active_rollout_workers )])
238251 )
239- self ._update_active_workers_and_urls ()
240- self .worker_server_urls = list (self .worker_server_urls_map .values ())
241- self .logger .info (f"Rollout worker server URLs: { self .worker_server_urls } " )
242- self .active_workers_to_status = {worker : True for worker in self .active_rollout_workers }
243- self .active_url_to_workers = dict (zip (self .worker_server_urls , self .active_rollout_workers ))
244- self .url_failed_counts = {url : 0 for url in self .worker_server_urls }
245- return engine_mesh_list , self .worker_server_urls_map
252+ active_rollout_workers , worker_server_urls_map = self ._update_active_workers_and_urls_map (
253+ active_rollout_workers , worker_server_urls_map
254+ )
255+ self .workers_info = {
256+ url : WorkerInfo (actor = worker )
257+ for url , worker in zip (worker_server_urls_map .values (), active_rollout_workers )
258+ }
259+ self .logger .info (f"Rollout worker server URLs: { list (self .workers_info .keys ())} " )
260+ return engine_mesh_list , worker_server_urls_map
261+
262+ def _deactivate_worker (self , url : str ):
263+ """A helper function to deactivate a worker, update all related states,
264+ and shut it down."""
265+ worker_info = self .workers_info .get (url )
266+ if not worker_info or not worker_info .is_active :
267+ return
268+
269+ self .logger .warning (f"Deactivating rollout worker { worker_info .actor } with URL { url } due to failures." )
270+ worker_info .is_active = False
271+ self .router .update_active_workers (self ._get_worker_status_for_router ())
272+
273+ ray .get (worker_info .actor .offload .remote ()) # type: ignore[attr-defined]
274+ ray .get (worker_info .actor .shutdown .remote ()) # type: ignore[attr-defined]
246275
247276 def check_active_workers (self ):
248277 """Check the health of all active rollout workers.
@@ -251,28 +280,41 @@ def check_active_workers(self):
251280 List[bool]: A list of booleans indicating the health status of
252281 each active rollout worker.
253282 """
283+ active_workers = [(url , info ) for url , info in self .workers_info .items () if info .is_active ]
284+ if not active_workers :
285+ return
254286
255- active_worker_response = ray .get (
256- [worker .check_health .remote () for worker in self .active_rollout_workers ] # type: ignore[attr-defined]
257- )
258- for idx , status in enumerate (active_worker_response ):
259- if not status :
260- self .logger .info (
261- f"Rollout worker { self .worker_server_urls [idx ]} is unhealthy. Removing it from active workers."
262- )
263- self .active_workers_to_status [self .active_rollout_workers [idx ]] = False
264-
265- def deactivate_worker_by_url (self , url ):
266- self .url_failed_counts [url ] += 1
267- if self .url_failed_counts [url ] < self .config .max_retry_per_worker :
287+ urls , infos = zip (* active_workers )
288+ actors = [info .actor for info in infos ]
289+
290+ health_statuses = ray .get ([actor .check_health .remote () for actor in actors ])
291+
292+ count = 0
293+ for url , is_healthy in zip (urls , health_statuses ):
294+ if count == 3 :
295+ is_healthy = False
296+ count += 1
297+ if not is_healthy :
298+ self ._deactivate_worker (url )
299+
300+ def deactivate_worker_by_url (self , url : str ):
301+ """Deactivates a worker identified by its URL after it exceeds the
302+ maximum retry count."""
303+ worker_info = self .workers_info .get (url )
304+ if not worker_info or not worker_info .is_active :
305+ return
306+
307+ worker_info .failure_count += 1
308+ if (
309+ self .config .max_retry_per_worker is not None
310+ and worker_info .failure_count < self .config .max_retry_per_worker
311+ ):
268312 self .logger .warning (
269- f"Rollout worker { url } failed { self . url_failed_counts [ url ] } times, but not deactivated yet."
313+ f"Rollout worker { url } failed { worker_info . failure_count } times, but not deactivated yet."
270314 )
271315 return
272- inactive_workers = self .active_url_to_workers .get (url )
273- if inactive_workers :
274- self .logger .warning (f"Deactivating rollout worker { url } due to repeated failures." )
275- self .active_workers_to_status [inactive_workers ] = False
316+
317+ self ._deactivate_worker (url )
276318
277319 async def rollout (
278320 self ,
@@ -429,9 +471,11 @@ def _broadcast_to_active_workers(self, method_name: str, block: bool):
429471 A list of futures if `block` is False, otherwise a list of results.
430472 """
431473 futures = []
432- for worker , status in self .active_workers_to_status .items ():
433- if status :
434- futures .append (getattr (worker , method_name ).remote ())
474+ for info in self .workers_info .values ():
475+ if info .is_active :
476+ futures .append (getattr (info .actor , method_name ).remote ())
477+ else :
478+ self .logger .warning (f"Skipping { method_name } for inactive worker { info .actor } ." )
435479
436480 if not block :
437481 return futures
0 commit comments