88from tqdm .auto import tqdm
99from typing_extensions import Annotated
1010
11- from xtuner .v1 .data_proto .rl_data import RLDataFlowItem
11+ from xtuner .v1 .data_proto .rl_data import RLDataFlowItem , check_dataflow_item
1212from xtuner .v1 .ray .environment import SingleTurnEnvironment
1313from xtuner .v1 .ray .rollout .controller import SampleParams
1414from xtuner .v1 .ray .utils import create_task
@@ -141,29 +141,35 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
141141 Optional[List[RLDataFlowItem]]: The group of samples if the task
142142 fails and needs to be retried, otherwise None.
143143 """
144- if group_samples_for_retry is not None :
145- for data_item in group_samples_for_retry :
146- data_item .extra_info .retry_times += 1
147-
148- group_data_items = group_samples_for_retry
149144 try :
150- # 该函数中所有的数据结构都是RLDataFlowItem
151145 # step 1: sample
152- with timer ("sample" , self .timer_dict ):
153- group_data_items = await self .replay_buffer .sample .remote ( # type: ignore[attr-defined]
154- self .env ,
155- self .config .enable_partial_rollout ,
156- self .config .prompt_repeat_k ,
157- )
158- self .send_samples_count += 1
159- self .logger .debug (
160- f"[ROLLOUT] Get 1 sample and dataflow have sent { self .send_samples_count } to rollout_controller"
161- )
146+ # TODO(@duanyanhui): More fine-grained control over group data generation:
147+ # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency
148+ # Resend only the failed prompts in a group when retrying worker_task to avoid wasted computation resources.
149+ if group_samples_for_retry is None or len (group_samples_for_retry ) == 0 :
150+ with timer ("sample" , self .timer_dict ):
151+ group_data_items = await self .replay_buffer .sample .remote ( # type: ignore[attr-defined]
152+ self .env ,
153+ self .config .enable_partial_rollout ,
154+ self .config .prompt_repeat_k ,
155+ )
156+ self .send_samples_count += 1
157+ self .logger .debug (
158+ f"[ROLLOUT] Get 1 sample and dataflow have sent { self .send_samples_count } to rollout_controller"
159+ )
160+ else :
161+ group_data_items = group_samples_for_retry
162+ for data_item in group_samples_for_retry :
163+ data_item .extra_info .retry_times += 1
164+
162165 # step 2: env generate
163166 with timer ("generate" , self .timer_dict ):
164167 group_data_items = await self .env_controller .run .remote ( # type: ignore[attr-defined]
165168 group_data_items , sample_params = self .sample_params , extra_params = self .extra_params
166169 )
170+ # 需要在这里处理check_dataflow_item,因为要保留group_data_items的data信息,作为retry的输入
171+ if not check_dataflow_item (group_data_items ):
172+ return group_data_items
167173
168174 # step 3: filter
169175 with timer ("post_process" , self .timer_dict ):
@@ -175,8 +181,6 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
175181
176182 except Exception as e :
177183 self .logger .error (f"Worker task failed with exception: { e } . Returning meta for retry." , exc_info = True )
178- for sample in group_data_items : # type: ignore[union-attr]
179- sample .extra_info .retry_times += 1
180184 return group_data_items
181185
182186 async def concurrent_task_runner (self ):
@@ -204,7 +208,10 @@ async def concurrent_task_runner(self):
204208 with tqdm (total = self .target_batch_size , desc = "rollout_controller for training samples" ) as pbar :
205209 update_step = max (1 , int (self .target_batch_size * 0.1 ))
206210 next_update_threshold = update_step
207- while self .finished_samples_count < self .target_batch_size :
211+ while (
212+ self .finished_samples_count < self .target_batch_size
213+ and self .failed_samples_count < self .target_batch_size
214+ ):
208215 if self .finished_samples_count >= next_update_threshold :
209216 pbar .n = self .finished_samples_count
210217 pbar .refresh ()
@@ -227,27 +234,36 @@ async def concurrent_task_runner(self):
227234 if result is not None :
228235 if result [0 ].extra_info .retry_times < self .config .max_retry_times :
229236 # If the retry count is less than max_retry_times, retry the task
237+ self .logger .info (
238+ f"Retrying task for { result [0 ].data } . Retry count: { result [0 ].extra_info .retry_times } "
239+ )
230240 retry_task = create_task (self .worker_task (group_samples_for_retry = result ))
231241 pending_tasks .add (retry_task )
232242 else :
233- self .logger .error (f"Max retry reached for { result [0 ]['prompt_id' ]} . Not retrying." )
234243 self .failed_samples_count += 1
235-
244+ self .logger .error (
245+ f"Max retry reached for { result [0 ].data } . Not retrying. Current failed count: { self .failed_samples_count } "
246+ )
236247 self .finished_samples_count = ray .get (self .replay_buffer .get_finished_samples .remote ())
237248 waiting_tasks = pending_tasks
238249
239250 pbar .n = self .finished_samples_count
240251 pbar .refresh ()
241252
242- self .logger .info ("Target batch size reached. Pausing env controller." )
253+ if self .finished_samples_count == self .target_batch_size :
254+ self .logger .info ("Target batch size reached. Pausing env controller." )
255+ if self .failed_samples_count == self .target_batch_size :
256+ self .logger .info ("Max failed samples reached. Pausing env controller." )
257+
243258 ray .get (self .env_controller .pause .remote ())
244259
245260 if waiting_tasks :
246261 await asyncio .wait_for (asyncio .gather (* waiting_tasks , return_exceptions = True ), timeout = 10 )
247262
248- self .unfinished_samples_count = ray .get (self .replay_buffer .get_unfinished_samples .remote ())
249- self .logging_replaybuffer_state ()
250- self .logging_timing_perf ()
263+ if self .finished_samples_count == self .target_batch_size :
264+ self .unfinished_samples_count = ray .get (self .replay_buffer .get_unfinished_samples .remote ())
265+ self .logging_replaybuffer_state ()
266+ self .logging_timing_perf ()
251267
252268 async def run (
253269 self ,
0 commit comments