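"""Weight-update integration test for the LMDeploy rollout backend.

Training workers push their parameters into a rollout engine started with
``skip_load_weights=True``; the greedy generation result is then compared
against a reference rollout engine that loads the same checkpoint from disk.
Requires ``ROLLOUT_MODEL_PATH`` to point at the model checkpoint and
``XTUNER_USE_LMDEPLOY`` set to a non-zero value to enable the test.
"""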
import os
import unittest

import ray

from xtuner.v1.ray.rollout import RolloutController
from xtuner.v1.data_proto.rl_data import SampleParams
from xtuner.v1.config import (
    AdamWConfig,
    FSDPConfig,
    LRConfig,
)
from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers
from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker
from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig
from xtuner.v1.model import get_model_config_from_hf

TEST_TEXT_MESSAGES = [{"role": "user", "content": "Hello!"}]
MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"]

class TestUpdateWeight(unittest.TestCase):
    def setUp(self):
        os.environ["XTUNER_USE_FA3"] = "1"
        ray.init(num_cpus=80, ignore_reinit_error=True)
        self.model_path = MODEL_PATH
        self.init_config()
        # one placement group shared by the training and rollout workers
        self.pg = AutoAcceleratorWorkers.build_placement_group(self.resources_cfg)

    def tearDown(self):
        ray.shutdown()
        del os.environ["XTUNER_USE_FA3"]

    def init_config(self):
        self.resources_cfg = AcceleratorResourcesConfig(
            accelerator="GPU",
            num_workers=2,
            num_cpus_per_worker=16,
            cpu_memory_per_worker=16 * 1024 ** 3,  # 16 GB
        )
        # rollout engine starts without loading weights; they are pushed from the trainer
        self.rollout_cfg = RolloutConfig(
            env="test_rollout",
            model_path=MODEL_PATH,
            model_name=os.path.basename(MODEL_PATH).lower(),
            tokenizer_path=MODEL_PATH,
            rollout_cross_node_comm=False,
            tensor_parallel_size=2,
            expert_parallel_size=1,
            gpus_per_node=8,  # gpu: 8, npu: 16
            dtype="bfloat16",
            skip_load_weights=True,
            context_length=256,
        )

        # training config
        model_cfg = get_model_config_from_hf(model_path=MODEL_PATH)
        # MoE auxiliary losses are only configured when the model exposes them
        if hasattr(model_cfg, "z_loss_cfg"):
            model_cfg.z_loss_cfg = ZLossConfig()
        if hasattr(model_cfg, "balancing_loss_cfg"):
            model_cfg.balancing_loss_cfg = BalancingLossConfig()
        optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False)
        fsdp_cfg: FSDPConfig = FSDPConfig()
        lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7)
        self.worker_cfg: WorkerConfig = WorkerConfig(
            model_cfg=model_cfg,
            optim_cfg=optim_cfg,
            loss_cfg=LossConfig(
                policy_loss_cfg=dict(
                    cliprange_high=0.28,
                    cliprange_low=0.2,
                    loss_type="vanilla",
                ),
                ignore_idx=-100,
                use_kl_loss=True,
                kl_loss_coef=0.001,
                kl_loss_type="low_var_kl",
                mode="eager",
            ),
            lr_cfg=lr_cfg,
            fsdp_cfg=fsdp_cfg,
            load_from=MODEL_PATH,
            sp_size=1,
            pack_max_length=1024,
        )

    @unittest.skipIf(os.environ.get("XTUNER_USE_LMDEPLOY", "0") == "0", "lmdeploy backend is not enabled")
    def test_lmdeploy_update_weight_and_generate(self):
        # init training workers; the env vars keep Ray from overriding
        # *_VISIBLE_DEVICES so each worker can see all local accelerators
        TrainingWorker = ray.remote(
            runtime_env={
                "env_vars": {
                    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1",
                    "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
                }
            },
        )(BaseTrainingWorker)
        train_workers, _ = AutoAcceleratorWorkers.from_placement_group(
            TrainingWorker, self.worker_cfg, self.pg
        )
        # sanity-check collective communication across the training workers
        futures = [worker.test_all_reduce.remote() for worker in train_workers]
        ray.get(futures)
        train_controller = TrainingController.remote(
            workers=train_workers,
        )
        ray.get(train_controller.__ray_ready__.remote())

        # fixed sample params: greedy decoding so the two rollouts are directly comparable
        sample_params = SampleParams(temperature=0.0, max_tokens=128, top_k=1)

        # init rollout engine that starts with empty weights (skip_load_weights=True)
        rollout_controller = RolloutController.remote(
            self.rollout_cfg,
            self.pg,
        )
        info_dict = ray.get(rollout_controller.get_rollout_info.remote())
        ray.get(train_controller.update_rollout_info.remote(info_dict))

        # update weights: push the trainer's parameters into the rollout engine
        ray.get(rollout_controller.offload.remote())
        ray.get(rollout_controller.onload_weights.remote())
        ray.get(train_controller.offload.remote(["optimizer"]))
        ray.get(train_controller.update_weights.remote())
        ray.get(train_controller.offload.remote(["model"]))
        ray.get(rollout_controller.onload_kvcache.remote())

        # generate with the weights received from the trainer
        res_update_weight = ray.get(rollout_controller.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
        ray.get(rollout_controller.shutdown.remote(), timeout=60)

        # init reference rollout engine that loads the checkpoint weights directly from disk
        self.rollout_cfg.skip_load_weights = False
        rollout_controller_ref = RolloutController.remote(
            self.rollout_cfg,
            self.pg,
        )

        res_ref = ray.get(rollout_controller_ref.rollout.remote(prompt=TEST_TEXT_MESSAGES, sample_params=sample_params))
        ray.get(rollout_controller_ref.shutdown.remote(), timeout=60)

        # greedy outputs must match if the pushed weights equal the checkpoint weights
        self.assertEqual(res_update_weight.response, res_ref.response)


if __name__ == "__main__":
    test_instance = TestUpdateWeight("test_lmdeploy_update_weight_and_generate")
    test_instance.setUp()
    try:
        test_instance.test_lmdeploy_update_weight_and_generate()
    finally:
        test_instance.tearDown()