Commit ec3175b

Merge branch 'main' into refactor-http
2 parents 03fef2c + 1911c71 commit ec3175b

File tree

1 file changed (+8, -4 lines)

xtuner/v1/ray/rollout/sglang.py

Lines changed: 8 additions & 4 deletions

@@ -161,8 +161,7 @@ def _transform_rollout_config_to_server_configs(self):
         sglang_server_args.port = self.server_port
         sglang_server_args.nccl_port = self.nccl_port
         sglang_server_args.dist_init_addr = self.dist_init_addr
-        base_gpu_id_interval = min(num_gpus_per_engine, self.config.gpus_per_node)
-        sglang_server_args.base_gpu_id = (self.rank * base_gpu_id_interval) % self.config.gpus_per_node
+        sglang_server_args.base_gpu_id = self.rank % self.config.gpus_per_node
         sglang_server_args.gpu_id_step = 1
         sglang_server_args.nnodes = max(1, num_gpus_per_engine // self.config.gpus_per_node)
         sglang_server_args.skip_server_warmup = True

@@ -178,8 +177,13 @@ def _transform_rollout_config_to_server_configs(self):
         sglang_server_args.log_level = log_level
         sglang_server_args.log_level_http = log_level_http
         sglang_server_args.enable_deterministic_inference = enable_deterministic_inference
-        sglang_server_args.tp_size = num_gpus_per_engine
-        sglang_server_args.ep_size = num_gpus_per_engine
+
+        if self.config.expert_parallel_size > 1:
+            sglang_server_args.tp_size = num_gpus_per_engine
+            sglang_server_args.ep_size = num_gpus_per_engine
+        else:
+            sglang_server_args.tp_size = self.config.tensor_parallel_size
+            sglang_server_args.ep_size = self.config.expert_parallel_size
 
         if grammar_backend is not None:
             sglang_server_args.grammar_backend = grammar_backend
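For context, here is a minimal, self-contained sketch of the placement and parallelism logic as it stands after this commit. The `Config` dataclass, its default field values, the `placement` helper, and the loop in the demo are hypothetical stand-ins invented for illustration; only the computed expressions mirror the two hunks above.

```python
from dataclasses import dataclass


# Hypothetical stand-in for the rollout config consumed in sglang.py;
# field names mirror the attributes referenced in the diff.
@dataclass
class Config:
    gpus_per_node: int = 8
    tensor_parallel_size: int = 2
    expert_parallel_size: int = 1


def placement(rank: int, num_gpus_per_engine: int, config: Config):
    # New mapping: each engine rank pins its first GPU by rank modulo
    # the node size, rather than scaling by a per-engine interval.
    base_gpu_id = rank % config.gpus_per_node
    # An engine spanning more GPUs than one node hosts becomes multi-node.
    nnodes = max(1, num_gpus_per_engine // config.gpus_per_node)
    # With expert parallelism enabled, tp/ep both cover the whole engine;
    # otherwise the configured tp/ep sizes are passed through unchanged.
    if config.expert_parallel_size > 1:
        tp_size = ep_size = num_gpus_per_engine
    else:
        tp_size = config.tensor_parallel_size
        ep_size = config.expert_parallel_size
    return base_gpu_id, nnodes, tp_size, ep_size


if __name__ == "__main__":
    cfg = Config()
    for rank in range(4):
        print(rank, placement(rank, num_gpus_per_engine=2, config=cfg))
```

The practical effect of the two hunks: `base_gpu_id` now depends only on the engine rank and the node size, and `tp_size`/`ep_size` collapse to the full engine width only when expert parallelism is actually enabled, instead of unconditionally.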
