|
25 | 25 | @ray.remote |
26 | 26 | class MockTimeoutRolloutController(RolloutController): |
27 | 27 | def _get_worker_cls(self): |
28 | | - return MockTimeoutRolloutWorker |
29 | | - |
| 28 | + return ray.remote(MockTimeoutRolloutWorker) |
30 | 29 | @ray.remote |
31 | 30 | class MockRequestErrorRolloutController(RolloutController): |
32 | 31 | def _get_worker_cls(self): |
33 | | - return MockRequestErrorRolloutWorker |
34 | | - |
| 32 | + return ray.remote(MockRequestErrorRolloutWorker) |
35 | 33 | @ray.remote |
36 | 34 | class MockClientErrorRolloutController(RolloutController): |
37 | 35 | def _get_worker_cls(self): |
38 | | - return MockClientErrorRolloutWorker |
39 | | - |
| 36 | + return ray.remote(MockClientErrorRolloutWorker) |
40 | 37 | @ray.remote |
41 | 38 | class MockServerErrorRolloutController(RolloutController): |
42 | 39 | def _get_worker_cls(self): |
43 | | - return MockServerErrorRolloutWorker |
| 40 | + return ray.remote(MockServerErrorRolloutWorker) |
44 | 41 |
|
45 | 42 | class TestMockRollout(unittest.TestCase): |
46 | 43 | @classmethod |
@@ -103,7 +100,7 @@ def tearDown(self): |
103 | 100 | ray.shutdown() |
104 | 101 |
|
105 | 102 | def _run_mock_test(self, mock_controller_cls, error_name: str): |
106 | | - rollout_controller = ray.remote(mock_controller_cls).remote(self.rollout_cfg, self.pg) |
| 103 | + rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg) |
107 | 104 | self.test_env = SingleTurnEnvironment.remote("env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller) |
108 | 105 | self.test_dataflow = DataFlow.remote("dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env) |
109 | 106 |
|
|
0 commit comments