14 changes: 0 additions & 14 deletions tests/ray/test_evaluator.py
@@ -107,20 +107,6 @@ def custom_compute_metric(samples):
custom_evaluator = Evaluator.remote(custom_evaluator_cfg, self.test_env)
custom_correctness = ray.get(custom_evaluator.run.remote())
self.assertEqual(correctness['accuracy'], custom_correctness['custom_accuracy'])

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_evaluator_with_failed_response(self):
evaluator_cfg = EvaluatorConfig(
dataset_cfg=self.eval_dataset_cfg,
tokenizer=self.tokenizer,
max_concurrent=1,
eval_sample_ratio=1, # generate 5 samples
sample_params=SampleParams(temperature=2.5), # invalid temperature to trigger error
max_retry_times=1,
)
evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
correctness = ray.get(evaluator.run.remote())
self.assertEqual(len(correctness), 0)

if __name__ == '__main__':
unittest.main()
132 changes: 132 additions & 0 deletions tests/ray/test_mock_rollout.py
@@ -0,0 +1,132 @@
import os
import unittest
import ray
from transformers import AutoTokenizer
import torch
import httpx
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig
from xtuner.v1.ray.environment import SingleTurnEnvironment
from xtuner.v1.datasets import RLTokenizeFnConfig
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.ray.rollout.controller import RolloutController
from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker

MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
resource_map = {"npu": "NPU", "cuda": "GPU"}
@ray.remote
class MockTimeoutRolloutController(RolloutController):
def _get_worker_cls(self):
return ray.remote(MockTimeoutRolloutWorker)

@ray.remote
class MockRequestErrorRolloutController(RolloutController):
def _get_worker_cls(self):
return ray.remote(MockRequestErrorRolloutWorker)

@ray.remote
class MockClientErrorRolloutController(RolloutController):
def _get_worker_cls(self):
return ray.remote(MockClientErrorRolloutWorker)

@ray.remote
class MockServerErrorRolloutController(RolloutController):
def _get_worker_cls(self):
return ray.remote(MockServerErrorRolloutWorker)

def deactivate_worker_by_url(self, url):
pass

class TestMockRollout(unittest.TestCase):
@classmethod
def setUpClass(cls):
os.environ["XTUNER_USE_FA3"] = "1"

@classmethod
def tearDownClass(cls):
del os.environ["XTUNER_USE_FA3"]
ray.shutdown()

def setUp(self):
ray.init(num_cpus=80, ignore_reinit_error=True)
self.global_batch_size = 3
self.max_prompt_length = 4096
self.max_response_length = 128
self.max_concurrent = 3
self.max_retry_times = 3

self.resources_cfg = AcceleratorResourcesConfig(
accelerator=resource_map[torch.accelerator.current_accelerator().type],
num_workers=8,
num_cpus_per_worker=8,
cpu_memory_per_worker=16 * 1024**3, # 16 GB
)
self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)

self.rollout_cfg = RolloutConfig(
env="test_mock_rollout",
model_path=MODEL_PATH,
model_name=os.path.basename(MODEL_PATH).lower(),
tokenizer_path=MODEL_PATH,
tensor_parallel_size=1,
context_length=self.max_prompt_length + self.max_response_length,
max_retry_per_worker=2
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

self.dataflow_cfg = DataFlowConfig(
max_concurrent=self.max_concurrent,
global_batch_size=self.global_batch_size,
max_retry_times=self.max_retry_times
)
train_dataset_cfg = [{
"dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH),
"tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length),
}]
dataloader_cfg = DataloaderConfig(
collator='fake_collator',
pack_level='none',
group_by_length=False,
)
self.replay_buffer_cfg = ReplayBufferConfig(
dataset_cfg=train_dataset_cfg,
dataloader_cfg=dataloader_cfg,
tokenizer=tokenizer,
)

def tearDown(self):
ray.shutdown()

def _run_mock_test(self, mock_controller_cls, error_name: str):
rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller)
self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env)

completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))

status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")
self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_rollout_with_timeout_mock(self):
self._run_mock_test(MockTimeoutRolloutController, "timeout")

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_rollout_with_request_error_mock(self):
self._run_mock_test(MockRequestErrorRolloutController, "request error")

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_rollout_with_client_error_mock(self):
self._run_mock_test(MockClientErrorRolloutController, "client error")

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_rollout_with_server_error_mock(self):
self._run_mock_test(MockServerErrorRolloutController, "server error")

if __name__ == "__main__":
unittest.main()
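The mock worker classes imported at the top of this file (MockTimeoutRolloutWorker and friends) are defined in xtuner.v1.utils.rl_test_utils and are not part of this diff. A minimal, hypothetical sketch of how such a worker could simulate a transport failure is shown below; the class name and the overridden hook are assumptions, not the actual implementation.

# Hypothetical sketch only: the real mock workers live in
# xtuner.v1.utils.rl_test_utils and may hook a different method.
import httpx


class SketchTimeoutRolloutWorker:
    """Stand-in for a rollout worker whose HTTP call always times out."""

    async def _send_request(self, url: str, payload: dict):
        # Simulate the inference server never answering in time.
        raise httpx.ReadTimeout(
            "mocked read timeout",
            request=httpx.Request("POST", url),
        )

The controller subclasses above then inject such workers through _get_worker_cls, so every request issued by DataFlow.run fails with the corresponding error class.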
30 changes: 4 additions & 26 deletions tests/ray/test_rollout.py
@@ -114,33 +114,11 @@ def setUp(self):
def tearDown(self):
ray.shutdown()

@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_dataflow_with_failed_request(self):
failed_dataflow_cfg = DataFlowConfig(
env="test",
max_concurrent=1,
prompt_repeat_k=2,
global_batch_size=1,
enable_partial_rollout=0,
)
self.test_env = SingleTurnEnvironment.remote(
"test_env",
self.pg,
rollout_cfg=self.rollout_cfg,
)
self.test_flow = DataFlow.remote("test_env",
failed_dataflow_cfg,
self.replay_buffer_cfg,
self.test_env
)
sample_params = SampleParams(temperature=2.5) # invalid temperature to trigger error
with self.assertRaises(AssertionError):
ray.get(self.test_flow.run.remote(num=1, sample_params=sample_params), timeout=300)


@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
def test_lmdeploy_generate(self):
sample_params = SampleParams(temperature=0.0)
rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined]
rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined]
res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))

self.assertEqual(res1.finish_reason, "stop")
@@ -208,7 +186,7 @@ def test_lmdeploy_turbomind_generate(self):
from xtuner.v1.ray.rollout import LMDeployWorker
self.rollout_cfg.extra_rollout_config["lmdeploy_backend"] = "turbomind"
sample_params = SampleParams(temperature=0.0)
rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined]
rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined]
res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
res2 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
self.assertEqual(res1, res2, f"res1 != res2, res1={res1}, res2={res2}")
@@ -219,7 +197,7 @@ def test_sglang_generate(self):
from xtuner.v1.ray.rollout import SGLangWorker
self.rollout_cfg.launch_server_method="multiprocessing"
sample_params = SampleParams(temperature=0.0)
rollout_controller = RolloutController.remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined]
rollout_controller = ray.remote(RolloutController).remote(self.rollout_cfg, self.pg) # type: ignore[attr-defined]
res1 = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
self.assertEqual(res1.finish_reason, "stop")
print("Response from SGLang infer:", res1)
4 changes: 2 additions & 2 deletions tests/ray/test_update_weight.py
@@ -109,7 +109,7 @@ def test_lmdeploy_update_weight_and_generate(self):
sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)

# init rollout_update
rollout_controller = RolloutController.remote(
rollout_controller = ray.remote(RolloutController).remote(
self.rollout_cfg,
self.pg,
)
@@ -129,7 +129,7 @@ def test_lmdeploy_update_weight_and_generate(self):

# init rollout_ref
self.rollout_cfg.skip_load_weights = False
rollout_controller_ref = RolloutController.remote(
rollout_controller_ref = ray.remote(RolloutController).remote(
self.rollout_cfg,
self.pg,
)
5 changes: 3 additions & 2 deletions xtuner/v1/data_proto/rl_data.py
@@ -69,7 +69,7 @@ class RLRolloutResponseItem(BaseModel):
response: Optional[str] = None
response_ids: Optional[List[int]] = None
num_return_tokens: Optional[int] = None
finish_reason: Optional[str] = None
finish_reason: Optional[str] = None # "stop", "length", "abort", "failed", "skipped"
logprobs: Optional[List[float]] = None
extra_info: Dict[str, Any] = dict()

@@ -153,7 +153,8 @@ def check_dataflow_item(group_data_items):

# If any item has an "abort" status, skip the check; the group will be rolled out again next time
is_abort = any(item.env.rollout.finish_reason == "abort" for item in group_data_items)
if is_abort:
is_skipped = any(item.env.rollout.finish_reason == "skipped" for item in group_data_items)
if is_abort or is_skipped:
return True

no_failures = all(item.env.rollout.finish_reason != "failed" for item in group_data_items)
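A simplified sketch of the finish_reason gate added above; the real check_dataflow_item performs further validation after the no-failures test that falls outside this hunk, so the final return below is illustrative only.

def finish_reason_gate(group_data_items) -> bool:
    # Simplified stand-in for the branch added in check_dataflow_item.
    reasons = [item.env.rollout.finish_reason for item in group_data_items]
    # "abort" and "skipped" short-circuit the check: the group is not treated
    # as a failure here and is handled upstream (retried or dropped).
    if any(r in ("abort", "skipped") for r in reasons):
        return True
    # Otherwise every response in the group must have avoided a "failed" status.
    return all(r != "failed" for r in reasons)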
4 changes: 3 additions & 1 deletion xtuner/v1/ray/base/accelerator.py
@@ -404,7 +404,9 @@ def from_placement_group(cls, worker_cls, worker_config, pg: PlacementGroup):
rank_bundle_idx_list = []
for rank, bundle_idx in enumerate(sorted_bundle_idxs):
worker = worker_cls.options(
placement_group=pg, placement_group_bundle_index=bundle_idx, **pg_options
placement_group=pg,
placement_group_bundle_index=bundle_idx,
**pg_options,
).remote(worker_config, rank, master_addr, master_port, world_size, device_type)
workers_list.append(worker)
rank_bundle_idx_list.append((rank, bundle_idx))
10 changes: 10 additions & 0 deletions xtuner/v1/ray/config/worker.py
@@ -198,6 +198,13 @@ class RolloutConfig(BaseModel):
help='Extra configuration for different rollout worker. vllm parameters will start with prefix "vllm", etc.',
),
] = {}
max_retry_per_worker: Annotated[
Optional[int],
Parameter(
group=infer_group,
help="Maximum number of retries per rollout worker before deactivation.",
),
] = None
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"

def model_post_init(self, __context: Any) -> None:
@@ -259,6 +266,9 @@ def model_post_init(self, __context: Any) -> None:
else:
self.rollout_max_batch_size_per_instance = 128

if self.max_retry_per_worker is None:
self.max_retry_per_worker = self.rollout_max_batch_size_per_instance

self.worker_log_dir.mkdir(parents=True, exist_ok=True)
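A minimal sketch of the defaulting behaviour introduced here: when max_retry_per_worker is left unset, model_post_init falls back to the resolved per-instance rollout batch size. Field names follow the diff; the 128 fallback mirrors the branch shown above, and the pydantic model is a cut-down stand-in for RolloutConfig, not its actual definition.

from typing import Any, Optional
from pydantic import BaseModel


class MiniRolloutConfig(BaseModel):
    rollout_max_batch_size_per_instance: Optional[int] = None
    max_retry_per_worker: Optional[int] = None

    def model_post_init(self, __context: Any) -> None:
        if self.rollout_max_batch_size_per_instance is None:
            self.rollout_max_batch_size_per_instance = 128  # fallback used in the diff
        if self.max_retry_per_worker is None:
            self.max_retry_per_worker = self.rollout_max_batch_size_per_instance


print(MiniRolloutConfig().max_retry_per_worker)                         # -> 128
print(MiniRolloutConfig(max_retry_per_worker=2).max_retry_per_worker)   # -> 2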


20 changes: 16 additions & 4 deletions xtuner/v1/ray/dataflow/flow.py
@@ -115,6 +115,7 @@
self.finished_samples_count = 0
self.unfinished_samples_count = 0
self.failed_samples_count = 0
self.skipped_sample_count = 0
self.logger = get_logger(log_dir=self.config.worker_log_dir, tag="DataFlow")
self.target_batch_size = self.config.global_batch_size
self.logger.info(f"DataFlowConfig:\n{self.config.model_dump_json(indent=2)}")
@@ -192,6 +193,13 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIt
f"Dataflow item check failed for {group_data_items[0].uid.action_id} response. Returning meta for retry."
)
return group_data_items
if any(item.env.rollout.finish_reason == "skipped" for item in group_data_items):
self.logger.warning(
f"Bad request for {group_data_items[0].uid.action_id} response. Skipping this request."
)
self.logger.debug(f"Worker task skipped successfully for {group_data_items[0].uid.action_id}.")
self.skipped_sample_count += 1
return

# step 3: filter
filtered_group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined]
@@ -228,7 +236,8 @@ async def concurrent_task_runner(self):
next_update_threshold = update_step
while (
self.finished_samples_count < self.target_batch_size
and self.failed_samples_count < self.target_batch_size
and self.failed_samples_count < self.target_batch_size * self.config.max_retry_times
and self.skipped_sample_count < self.target_batch_size * self.config.max_retry_times
):
if self.finished_samples_count >= next_update_threshold:
pbar.n = self.finished_samples_count
@@ -253,14 +262,14 @@
if result[0].extra_info.retry_times < self.config.max_retry_times:
# If the retry count is less than max_retry_times, retry the task
self.logger.info(
f"Retrying task for {result[0].data}. Retry count: {result[0].extra_info.retry_times}"
f"Retrying task for {result[0].uid.action_id}. Retry count: {result[0].extra_info.retry_times}"
)
retry_task = create_task(self.worker_task(group_samples_for_retry=result))
pending_tasks.add(retry_task)
else:
self.failed_samples_count += 1
self.logger.error(
f"Max retry reached for {result[0].data}. Not retrying. Current failed count: {self.failed_samples_count}"
f"Max retry reached for {result[0].uid.action_id}. Not retrying. Current failed count: {self.failed_samples_count}"
)
self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
waiting_tasks = pending_tasks
@@ -270,8 +279,10 @@

if self.finished_samples_count >= self.target_batch_size:
self.logger.info("Target batch size reached. Pausing env controller.")
if self.failed_samples_count >= self.target_batch_size:
if self.failed_samples_count >= self.target_batch_size * self.config.max_retry_times:
self.logger.info("Max failed samples reached. Pausing env controller.")
if self.skipped_sample_count >= self.target_batch_size * self.config.max_retry_times:
self.logger.info("Max skipped samples reached. Pausing env controller.")

# NOTE: Directly send pause requests to rollout workers because calling `rollout_controller.pause()`
# would be queued behind many worker tasks, causing a significant delay.
@@ -344,6 +355,7 @@ async def run(
self.finished_samples_count = 0
self.unfinished_samples_count = 0
self.failed_samples_count = 0
self.skipped_sample_count = 0
self.target_batch_size = num if num and num > 0 else self.config.global_batch_size
self.logger.info(f"Start generate dataflow and target batch size set to {self.target_batch_size}.")
self.sample_params = sample_params if sample_params else self.config.sample_params
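Taken together, the counters reset here and checked in concurrent_task_runner give the generation loop three exits: enough finished samples, or a failure or skip budget of target_batch_size * max_retry_times exhausted. A standalone sketch of that condition, with names mirroring the diff but detached from the surrounding task scheduling:

def should_continue(finished: int, failed: int, skipped: int,
                    target_batch_size: int, max_retry_times: int) -> bool:
    # Mirrors the while-condition in concurrent_task_runner after this change.
    budget = target_batch_size * max_retry_times
    return (finished < target_batch_size
            and failed < budget
            and skipped < budget)


assert should_continue(0, 0, 0, target_batch_size=3, max_retry_times=3)
assert not should_continue(0, 9, 0, target_batch_size=3, max_retry_times=3)  # failure budget hit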