@@ -161,8 +161,7 @@ def _transform_rollout_config_to_server_configs(self):
161161 sglang_server_args .port = self .server_port
162162 sglang_server_args .nccl_port = self .nccl_port
163163 sglang_server_args .dist_init_addr = self .dist_init_addr
164- base_gpu_id_interval = min (num_gpus_per_engine , self .config .gpus_per_node )
165- sglang_server_args .base_gpu_id = (self .rank * base_gpu_id_interval ) % self .config .gpus_per_node
164+ sglang_server_args .base_gpu_id = self .rank % self .config .gpus_per_node
166165 sglang_server_args .gpu_id_step = 1
167166 sglang_server_args .nnodes = max (1 , num_gpus_per_engine // self .config .gpus_per_node )
168167 sglang_server_args .skip_server_warmup = True
@@ -178,8 +177,13 @@ def _transform_rollout_config_to_server_configs(self):
178177 sglang_server_args .log_level = log_level
179178 sglang_server_args .log_level_http = log_level_http
180179 sglang_server_args .enable_deterministic_inference = enable_deterministic_inference
181- sglang_server_args .tp_size = num_gpus_per_engine
182- sglang_server_args .ep_size = num_gpus_per_engine
180+
181+ if self .config .expert_parallel_size > 1 :
182+ sglang_server_args .tp_size = num_gpus_per_engine
183+ sglang_server_args .ep_size = num_gpus_per_engine
184+ else :
185+ sglang_server_args .tp_size = self .config .tensor_parallel_size
186+ sglang_server_args .ep_size = self .config .expert_parallel_size
183187
184188 if grammar_backend is not None :
185189 sglang_server_args .grammar_backend = grammar_backend
0 commit comments