-
Notifications
You must be signed in to change notification settings - Fork 387
[Refactor] refactor http request error #1259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+591
−180
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
2cb5c1e
[Refactor] refactor http request error
YanhuiDua c011186
add ut
YanhuiDua 9634810
fix ut
YanhuiDua 8ba436a
fix
YanhuiDua b84fb49
fix ut
YanhuiDua 4ca6798
fix
YanhuiDua d3aec12
fix
YanhuiDua 0e98700
fix ut
YanhuiDua 7fa5c77
fix comments
YanhuiDua cbd5a4c
fix comments
YanhuiDua bc04815
fix comments
YanhuiDua d422029
fix comments
YanhuiDua e44e83e
fix ci
YanhuiDua 5c8c03e
fix
YanhuiDua 2f87ebf
fix
YanhuiDua 03fef2c
fix deactivate rollout worker
YanhuiDua e00a267
rm useless code
YanhuiDua c240f08
fix ci
YanhuiDua File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import os | ||
| import unittest | ||
| import ray | ||
| from transformers import AutoTokenizer | ||
| import torch | ||
| import httpx | ||
| from xtuner.v1.ray.config.worker import RolloutConfig | ||
| from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers | ||
| from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, ReplayBufferConfig | ||
| from xtuner.v1.ray.environment import SingleTurnEnvironment | ||
| from xtuner.v1.datasets import RLTokenizeFnConfig | ||
| from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig | ||
| from xtuner.v1.ray.rollout.controller import RolloutController | ||
| from xtuner.v1.utils.rl_test_utils import MockTimeoutRolloutWorker, MockRequestErrorRolloutWorker, MockClientErrorRolloutWorker, MockServerErrorRolloutWorker | ||
|
|
||
# Both paths are mandatory: a missing variable raises KeyError at import time.
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]
TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"]
# Keyed by ``torch.accelerator.current_accelerator().type``; the value is
# passed as ``accelerator=`` to ``AcceleratorResourcesConfig`` in ``setUp``.
resource_map = {
    "npu": "NPU",
    "cuda": "GPU",
}
@ray.remote
class MockTimeoutRolloutController(RolloutController):
    """Rollout controller actor that uses ``MockTimeoutRolloutWorker`` workers."""

    def _get_worker_cls(self):
        """Return ``MockTimeoutRolloutWorker`` wrapped by ``ray.remote``."""
        worker_actor_cls = ray.remote(MockTimeoutRolloutWorker)
        return worker_actor_cls
|
|
||
@ray.remote
class MockRequestErrorRolloutController(RolloutController):
    """Rollout controller actor that uses ``MockRequestErrorRolloutWorker`` workers."""

    def _get_worker_cls(self):
        """Return ``MockRequestErrorRolloutWorker`` wrapped by ``ray.remote``."""
        worker_actor_cls = ray.remote(MockRequestErrorRolloutWorker)
        return worker_actor_cls
|
|
||
@ray.remote
class MockClientErrorRolloutController(RolloutController):
    """Rollout controller actor that uses ``MockClientErrorRolloutWorker`` workers."""

    def _get_worker_cls(self):
        """Return ``MockClientErrorRolloutWorker`` wrapped by ``ray.remote``."""
        worker_actor_cls = ray.remote(MockClientErrorRolloutWorker)
        return worker_actor_cls
|
|
||
@ray.remote
class MockServerErrorRolloutController(RolloutController):
    """Rollout controller actor that uses ``MockServerErrorRolloutWorker`` workers."""

    def _get_worker_cls(self):
        """Return ``MockServerErrorRolloutWorker`` wrapped by ``ray.remote``."""
        worker_actor_cls = ray.remote(MockServerErrorRolloutWorker)
        return worker_actor_cls

    def deactivate_worker_by_url(self, url):
        """Do nothing; ``url`` is ignored.

        NOTE(review): presumably this shadows a ``RolloutController`` method
        so that workers are never deactivated in the server-error scenario --
        confirm against the base class.
        """
        return None
|
|
||
class TestMockRollout(unittest.TestCase):
    """Check that a dataflow backed by always-failing workers completes nothing.

    Each test plugs a mock rollout controller (timeout, request error, client
    error or server error) into a ``SingleTurnEnvironment`` / ``DataFlow``
    pair and asserts that no rollout is reported as finished or paused.
    """

    # Value of XTUNER_USE_FA3 before setUpClass overwrote it (None if unset).
    _saved_use_fa3 = None

    @classmethod
    def setUpClass(cls):
        """Force ``XTUNER_USE_FA3=1``, remembering any pre-existing value."""
        cls._saved_use_fa3 = os.environ.get("XTUNER_USE_FA3")
        os.environ["XTUNER_USE_FA3"] = "1"

    @classmethod
    def tearDownClass(cls):
        """Put ``XTUNER_USE_FA3`` back to its prior state and shut Ray down."""
        # Restore rather than unconditionally delete: the variable may have
        # been set before this class ran, or already removed by other code.
        if cls._saved_use_fa3 is None:
            os.environ.pop("XTUNER_USE_FA3", None)
        else:
            os.environ["XTUNER_USE_FA3"] = cls._saved_use_fa3
        ray.shutdown()

    def setUp(self):
        """Start Ray and build the configs shared by every test."""
        ray.init(num_cpus=80, ignore_reinit_error=True)
        self.global_batch_size = 3
        self.max_prompt_length = 4096
        self.max_response_length = 128
        self.max_concurrent = 3
        self.max_retry_times = 3

        self.resources_cfg = AcceleratorResourcesConfig(
            accelerator=resource_map[torch.accelerator.current_accelerator().type],
            num_workers=8,
            num_cpus_per_worker=8,
            cpu_memory_per_worker=16 * 1024**3,  # 16 GiB
        )
        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)

        self.rollout_cfg = RolloutConfig(
            env="test_mock_rollout",
            model_path=MODEL_PATH,
            model_name=os.path.basename(MODEL_PATH).lower(),
            tokenizer_path=MODEL_PATH,
            tensor_parallel_size=1,
            context_length=self.max_prompt_length + self.max_response_length,
            max_retry_per_worker=2,
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

        self.dataflow_cfg = DataFlowConfig(
            max_concurrent=self.max_concurrent,
            global_batch_size=self.global_batch_size,
            max_retry_times=self.max_retry_times,
        )
        train_dataset_cfg = [
            {
                "dataset": DatasetConfig(name="mock_data", anno_path=TRAIN_DATA_PATH),
                "tokenize_fn": RLTokenizeFnConfig(max_length=self.max_prompt_length),
            }
        ]
        dataloader_cfg = DataloaderConfig(
            collator="fake_collator",
            pack_level="none",
            group_by_length=False,
        )
        self.replay_buffer_cfg = ReplayBufferConfig(
            dataset_cfg=train_dataset_cfg,
            dataloader_cfg=dataloader_cfg,
            tokenizer=tokenizer,
        )

    def tearDown(self):
        """Shut Ray down after each test; ``setUp`` starts it again."""
        ray.shutdown()

    def _run_mock_test(self, mock_controller_cls, error_name: str):
        """Run the dataflow against a mock controller and expect zero output.

        Args:
            mock_controller_cls: Ray actor class of the mock rollout
                controller to instantiate.
            error_name: Label for the simulated failure, used only in the
                printed summary and the assertion messages.
        """
        rollout_controller = mock_controller_cls.remote(self.rollout_cfg, self.pg)
        self.test_env = SingleTurnEnvironment.remote(
            "env", self.pg, self.rollout_cfg, rollout_controller=rollout_controller
        )
        self.test_dataflow = DataFlow.remote(
            "dataflow", self.dataflow_cfg, self.replay_buffer_cfg, self.test_env
        )

        completed_rollouts = ray.get(self.test_dataflow.run.remote(num=3))

        status = ray.get(self.test_dataflow.get_replaybuffer_status.remote())
        print(f"[{error_name}] Completed rollouts: {completed_rollouts}, Status: {status}")
        # NOTE(review): element 0 of the run() result is assumed to be the
        # collection of completed rollouts -- confirm against DataFlow.run.
        self.assertEqual(len(completed_rollouts[0]), 0, f"[{error_name}] Expected no rollouts to complete successfully.")
        self.assertEqual(status["rollout_finished_count"], 0, f"[{error_name}] Completed count in buffer should be 0.")
        self.assertEqual(status["rollout_paused_count"], 0, f"[{error_name}] Expected no rollouts to be interrupted.")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_timeout_mock(self):
        """No rollout completes when the mock timeout workers are used."""
        self._run_mock_test(MockTimeoutRolloutController, "timeout")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_request_error_mock(self):
        """No rollout completes when the mock request-error workers are used."""
        self._run_mock_test(MockRequestErrorRolloutController, "request error")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_client_error_mock(self):
        """No rollout completes when the mock client-error workers are used."""
        self._run_mock_test(MockClientErrorRolloutController, "client error")

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_rollout_with_server_error_mock(self):
        """No rollout completes when the mock server-error workers are used."""
        self._run_mock_test(MockServerErrorRolloutController, "server error")
|
|
||
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.