1+ import os
2+ import unittest
3+ import ray
4+ from transformers import AutoTokenizer
5+ import torch
6+ import httpx
7+ from xtuner .v1 .ray .config .worker import RolloutConfig
8+ from xtuner .v1 .ray .base import AcceleratorResourcesConfig , AutoAcceleratorWorkers
9+ from xtuner .v1 .ray .dataflow import DataFlow , DataFlowConfig , ReplayBufferConfig
10+ from xtuner .v1 .ray .environment import SingleTurnEnvironment
11+ from xtuner .v1 .datasets import RLTokenizeFnConfig
12+ from xtuner .v1 .datasets .config import DataloaderConfig , DatasetConfig
13+ from xtuner .v1 .ray .rollout .controller import RolloutController
14+ from xtuner .v1 .utils .rl_test_utils import MockTimeoutRolloutWorker , MockRequestErrorRolloutWorker , MockClientErrorRolloutWorker , MockServerErrorRolloutWorker
15+
# Paths to the rollout model and the training data, supplied by the CI/test
# environment. Indexing os.environ (not .get) fails fast with a KeyError at
# import time if a variable is missing, instead of failing mid-test.
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
# Maps a torch accelerator type (torch.accelerator.current_accelerator().type)
# to the Ray resource label used when requesting workers.
resource_map = {"npu": "NPU", "cuda": "GPU"}
@ray.remote
class MockTimeoutRolloutController(RolloutController):
    """Rollout controller whose workers always simulate request timeouts."""

    def _get_worker_cls(self):
        # Substitute the timeout-mocking worker for the real rollout worker.
        worker_cls = MockTimeoutRolloutWorker
        return ray.remote(worker_cls)
23+
@ray.remote
class MockRequestErrorRolloutController(RolloutController):
    """Rollout controller whose workers always raise request errors."""

    def _get_worker_cls(self):
        # Substitute the request-error-mocking worker for the real one.
        worker_cls = MockRequestErrorRolloutWorker
        return ray.remote(worker_cls)
28+
@ray.remote
class MockClientErrorRolloutController(RolloutController):
    """Rollout controller whose workers always return HTTP 4xx client errors."""

    def _get_worker_cls(self):
        # Substitute the client-error-mocking worker for the real one.
        worker_cls = MockClientErrorRolloutWorker
        return ray.remote(worker_cls)
33+
@ray.remote
class MockServerErrorRolloutController(RolloutController):
    """Rollout controller whose workers always return HTTP 5xx server errors."""

    def _get_worker_cls(self):
        # Substitute the server-error-mocking worker for the real one.
        worker_cls = MockServerErrorRolloutWorker
        return ray.remote(worker_cls)

    def deactivate_worker_by_url(self, url):
        # Deliberate no-op: keep erroring workers "active" so the retry
        # path keeps hitting them instead of removing them from rotation.
        pass
41+
class TestMockRollout(unittest.TestCase):
    """Exercise DataFlow error handling against controllers that always fail.

    Each test plugs a mock controller (timeout / request / client / server
    error) into a single-turn environment and asserts that no rollout
    completes and nothing is left paused in the replay buffer.
    """

    @classmethod
    def setUpClass(cls):
        os.environ["XTUNER_USE_FA3"] = "1"

    @classmethod
    def tearDownClass(cls):
        del os.environ["XTUNER_USE_FA3"]
        ray.shutdown()

    def setUp(self):
        ray.init(num_cpus=80, ignore_reinit_error=True)

        # Sizing knobs shared by the rollout and dataflow configs below.
        self.global_batch_size = 3
        self.max_prompt_length = 4096
        self.max_response_length = 128
        self.max_concurrent = 3
        self.max_retry_times = 3

        # Accelerator resources + placement group for the mock workers.
        accelerator_type = torch.accelerator.current_accelerator().type
        self.resources_cfg = AcceleratorResourcesConfig(
            accelerator=resource_map[accelerator_type],
            num_workers=8,
            num_cpus_per_worker=8,
            cpu_memory_per_worker=16 * 1024**3,  # 16 GB
        )
        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)

        self.rollout_cfg = RolloutConfig(
            env="test_mock_rollout",
            model_path=MODEL_PATH,
            model_name=os.path.basename(MODEL_PATH).lower(),
            tokenizer_path=MODEL_PATH,
            tensor_parallel_size=1,
            # Context must fit the longest prompt plus the longest response.
            context_length=self.max_prompt_length + self.max_response_length,
            max_retry_per_worker=2,
        )

        self.dataflow_cfg = DataFlowConfig(
            max_concurrent=self.max_concurrent,
            global_batch_size=self.global_batch_size,
            max_retry_times=self.max_retry_times,
        )

        # Replay buffer: one mock dataset, a no-op collator, no packing.
        dataset_cfgs = [
            {
                "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH),
                "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length),
            }
        ]
        loader_cfg = DataloaderConfig(
            collator="fake_collator",
            pack_level="none",
            group_by_length=False,
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
        self.replay_buffer_cfg = ReplayBufferConfig(
            dataset_cfg=dataset_cfgs,
            dataloader_cfg=loader_cfg,
            tokenizer=tokenizer,
        )

    def tearDown(self):
        ray.shutdown()

    def _run_mock_test(self, mock_controller_cls, error_name: str):
        """Run one dataflow pass with a failing controller and assert zero completions."""
        rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
        self.test_env = SingleTurnEnvironment.remote(
            "env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller
        )
        self.test_dataflow = DataFlow.remote(
            "dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env
        )

        completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))
        status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())

        print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")
        # Every request fails, so nothing completes, nothing is counted as
        # finished, and nothing is left paused/interrupted in the buffer.
        self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
        self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
        self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_timeout_mock(self):
        self._run_mock_test(MockTimeoutRolloutController, "timeout")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_request_error_mock(self):
        self._run_mock_test(MockRequestErrorRolloutController, "request error")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_client_error_mock(self):
        self._run_mock_test(MockClientErrorRolloutController, "client error")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_server_error_mock(self):
        self._run_mock_test(MockServerErrorRolloutController, "server error")
130+
# Allow running this test module directly (e.g. `python this_file.py`).
if __name__ == "__main__":
    unittest.main()