Skip to content

Commit 8a0d8bf

Browse files
committed
fix ut
1 parent d0500e7 commit 8a0d8bf

File tree

2 files changed

+5
-22
lines changed

2 files changed

+5
-22
lines changed

tests/ray/test_evaluator.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -107,20 +107,6 @@ def custom_compute_metric(samples):
107107
custom_evaluator = Evaluator.remote(custom_evaluator_cfg, self.test_env)
108108
custom_correctness = ray.get(custom_evaluator.run.remote())
109109
self.assertEqual(correctness['accuracy'], custom_correctness['custom_accuracy'])
110-
111-
@unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
112-
def test_lmdeploy_evaluator_with_failed_response(self):
113-
evaluator_cfg = EvaluatorConfig(
114-
dataset_cfg=self.eval_dataset_cfg,
115-
tokenizer=self.tokenizer,
116-
max_concurrent=1,
117-
eval_sample_ratio=1, # generate 5 samples
118-
sample_params=SampleParams(temperature=2.5), # invalid temperature to trigger error
119-
max_retry_times=1,
120-
)
121-
evaluator = Evaluator.remote(evaluator_cfg, self.test_env)
122-
correctness = ray.get(evaluator.run.remote())
123-
self.assertEqual(len(correctness), 0)
124110

125111
if __name__ == '__main__':
126112
unittest.main()

tests/ray/test_mock_rollout.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,19 @@
2525
@ray.remote
2626
class MockTimeoutRolloutController(RolloutController):
2727
def _get_worker_cls(self):
28-
return MockTimeoutRolloutWorker
29-
28+
return ray.remote(MockTimeoutRolloutWorker)
3029
@ray.remote
3130
class MockRequestErrorRolloutController(RolloutController):
3231
def _get_worker_cls(self):
33-
return MockRequestErrorRolloutWorker
34-
32+
return ray.remote(MockRequestErrorRolloutWorker)
3533
@ray.remote
3634
class MockClientErrorRolloutController(RolloutController):
3735
def _get_worker_cls(self):
38-
return MockClientErrorRolloutWorker
39-
36+
return ray.remote(MockClientErrorRolloutWorker)
4037
@ray.remote
4138
class MockServerErrorRolloutController(RolloutController):
4239
def _get_worker_cls(self):
43-
return MockServerErrorRolloutWorker
40+
return ray.remote(MockServerErrorRolloutWorker)
4441

4542
class TestMockRollout(unittest.TestCase):
4643
@classmethod
@@ -103,7 +100,7 @@ def tearDown(self):
103100
ray.shutdown()
104101

105102
def _run_mock_test(self, mock_controller_cls, error_name: str):
106-
rollout_controller = ray.remote(mock_controller_cls).remote(self.rollout_cfg, self.pg)
103+
rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
107104
self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller)
108105
self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env)
109106

0 commit comments

Comments
 (0)