diff --git a/python/sglang/srt/configs/load_config.py b/python/sglang/srt/configs/load_config.py
index 042eb322afb..6d8762576da 100644
--- a/python/sglang/srt/configs/load_config.py
+++ b/python/sglang/srt/configs/load_config.py
@@ -28,6 +28,7 @@ class LoadFormat(str, enum.Enum):
     REMOTE_INSTANCE = "remote_instance"
     RDMA = "rdma"
     LOCAL_CACHED = "local_cached"
+    SERVERLESS_LLM = "serverless_llm"
 
 
 @dataclass
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 5a0b86f0018..82100ee7d01 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -608,6 +608,9 @@ def save_remote_model(self, **kwargs):
     def save_sharded_model(self, **kwargs):
         self.collective_rpc("save_sharded_model", **kwargs)
 
+    def save_serverless_llm_state(self, **kwargs):
+        self.collective_rpc("save_serverless_llm_state", **kwargs)
+
     def score(
         self,
         query: Optional[Union[str, List[int]]] = None,
@@ -748,7 +751,17 @@ def launch_phase_sigquit_handler(signum, frame):
         )
         kill_process_tree(os.getpid())
 
-    signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
+    # Only register signal handler if we're in the main thread.
+    # When running in Ray actors or other non-main threads, signal registration will fail.
+    try:
+        signal.signal(signal.SIGQUIT, launch_phase_sigquit_handler)
+    except ValueError:
+        # signal only works in the main thread of the main interpreter.
+        # This is expected when running in Ray actors or subprocesses.
+        logger.warning(
+            "Cannot register signal handler (not in main thread). "
+            "Signal handling will be disabled."
+        )
 
     # Set mp start method
     mp.set_start_method("spawn", force=True)
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index fa0d612e2e9..79ab42f42c2 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -184,6 +184,10 @@ def save_sharded_model(self: Scheduler, params):
             max_size=params["max_size"],
         )
 
+    def save_serverless_llm_state(self: Scheduler, **kwargs):
+        params = kwargs
+        self.tp_worker.model_runner.save_serverless_llm_state(path=params["path"])
+
 
 def _export_static_state(model):
     return dict(
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 76226cca9fd..24be6a71750 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -2372,6 +2372,12 @@ def save_sharded_model(
         )
         ShardedStateLoader.save_model(self.model, path, pattern, max_size)
 
+    def save_serverless_llm_state(self, path: str):
+        from sglang.srt.model_loader.loader import ServerlessLLMModelLoader
+
+        logger.info(f"Save ServerlessLLM model state to {path}")
+        ServerlessLLMModelLoader.save_model(self.model, path)
+
     def update_weights_from_ipc(self, recv_req):
         """Update weights from IPC for checkpoint-engine integration."""
         try:
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 8b76a32012b..97c662ad062 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -24,6 +24,7 @@
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     cast,
 )
@@ -284,6 +285,223 @@ def load_model(
         raise NotImplementedError
 
 
+class ServerlessLLMModelLoader(BaseModelLoader):
+    """Model loader that loads weights from the ServerlessLLM checkpoint store.
+
+    This mirrors the vLLM sllm_loader behavior: each TP rank reads its own
+    shard from `model_path/rank_{tp_rank}` via `sllm_store.torch.load_dict`.
+    """
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        if load_config.model_loader_extra_config:
+            raise ValueError(
+                f"Model loader extra config is not supported for load format {load_config.load_format}"
+            )
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        # Nothing to download; the ServerlessLLM store streams tensors at load time.
+        pass
+
+    @staticmethod
+    def _filter_subtensors(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """Filter out view/sub-tensors that share storage with other tensors."""
+        same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
+            collections.defaultdict(list)
+        )
+        for key, tensor in tensors.items():
+            if tensor.numel():
+                ptr = tensor.untyped_storage().data_ptr()
+                same_storage_groups[tensor.device, ptr].append((key, tensor))
+
+        def get_end_ptr(tensor: torch.Tensor) -> int:
+            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+
+        result: Dict[str, torch.Tensor] = {}
+        for group in same_storage_groups.values():
+            for k, t in group:
+                a, b = t.data_ptr(), get_end_ptr(t)
+                for k2, t2 in group:
+                    if not t2.is_contiguous():
+                        continue
+                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
+                    if a < a2 or b2 < b:
+                        continue
+                    if a2 < a or b < b2 or not t.is_contiguous():
+                        break
+                    if k2 < k:
+                        break
+                else:
+                    result[k] = t
+        return result
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        try:
+            from sllm_store.torch import load_dict
+        except Exception as e:
+            logger.error(
+                f"Failed to import sllm_store.torch.load_dict: {e}; check whether sllm_store is installed correctly."
+            )
+            raise e
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+        tp_rank = get_tensor_model_parallel_rank()
+        local_model_path = os.path.join(model_config.model_path, f"rank_{tp_rank}")
+        model_id = self._get_model_id(local_model_path, self._get_storage_path())
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model = self._initialize_model_on_cpu(model_config)
+
+        device_map = self._build_device_map(target_device)
+        try:
+            sllm_state_dict = load_dict(model_id, device_map)
+        except Exception as e:
+            logger.error(
+                f"Failed to load model from sllm_store: {e}; check whether the sllm_store server is running."
+            )
+            raise e
+
+        full_state_dict = model.state_dict()
+        storage_to_keys = self._build_storage_to_keys(full_state_dict)
+        loaded_params = self._assign_parameters(
+            model, sllm_state_dict, full_state_dict, storage_to_keys
+        )
+        self._assert_all_parameters_loaded(model, loaded_params)
+        self._run_post_load_hooks(model, target_device, model_config)
+
+        return model.eval()
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        path: str,
+    ) -> None:
+        """Save current TP shard tensors to ServerlessLLM store format."""
+        from sllm_store.torch import save_dict
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+        rank = get_tensor_model_parallel_rank()
+        state_dict = ServerlessLLMModelLoader._filter_subtensors(model.state_dict())
+
+        # Move to CPU for saving (the store accepts CPU tensors)
+        cpu_state: Dict[str, torch.Tensor] = {}
+        for key, tensor in state_dict.items():
+            cpu_state[key] = tensor.detach().to("cpu").contiguous()
+
+        save_path = os.path.join(path, f"rank_{rank}")
+        os.makedirs(save_path, exist_ok=True)
+        save_dict(cpu_state, save_path)
+
+    @staticmethod
+    def _get_storage_path() -> str:
+        return os.getenv("STORAGE_PATH", "./models")
+
+    @staticmethod
+    def _get_model_id(local_model_path: str, storage_path: str) -> str:
+        normalized_path = os.path.normpath(local_model_path)
+        normalized_storage = os.path.normpath(storage_path)
+        if normalized_path.startswith(normalized_storage):
+            return normalized_path[len(normalized_storage) :].lstrip(os.sep)
+        return normalized_path
+
+    def _initialize_model_on_cpu(self, model_config: ModelConfig) -> nn.Module:
+        with torch.device("cpu"):
+            model = _initialize_model(model_config, self.load_config)
+        return model.eval()
+
+    @staticmethod
+    def _build_device_map(target_device: torch.device) -> Dict[str, int]:
+        if target_device.type == "cuda":
+            return {"": torch.cuda.current_device()}
+        return {"": 0}
+
+    @staticmethod
+    def _build_storage_to_keys(
+        state_dict: Dict[str, torch.Tensor],
+    ) -> Dict[Tuple[torch.device, int], List[str]]:
+        storage_to_keys: Dict[Tuple[torch.device, int], List[str]] = (
+            collections.defaultdict(list)
+        )
+        for key, tensor in state_dict.items():
+            if tensor.numel() > 0:
+                storage_ptr = (tensor.device, tensor.untyped_storage().data_ptr())
+                storage_to_keys[storage_ptr].append(key)
+        return storage_to_keys
+
+    @staticmethod
+    def _assign_parameters(
+        model: nn.Module,
+        sllm_state_dict: Dict[str, torch.Tensor],
+        full_state_dict: Dict[str, torch.Tensor],
+        storage_to_keys: Dict[Tuple[torch.device, int], List[str]],
+    ) -> Set[str]:
+        loaded_params: Set[str] = set()
+        for key, param in model.named_parameters(recurse=True):
+            tensor = sllm_state_dict.get(key)
+            if tensor is not None:
+                param.data = tensor
+                loaded_params.add(key)
+                continue
+
+            orig_tensor = full_state_dict.get(key)
+            if orig_tensor is None or orig_tensor.numel() == 0:
+                continue
+
+            storage_ptr = (orig_tensor.device, orig_tensor.untyped_storage().data_ptr())
+            for saved_key in storage_to_keys.get(storage_ptr, []):
+                saved_tensor = sllm_state_dict.get(saved_key)
+                if saved_tensor is None:
+                    continue
+                param.data = saved_tensor
+                loaded_params.add(key)
+                break
+
+        return loaded_params
+
+    @staticmethod
+    def _assert_all_parameters_loaded(
+        model: nn.Module, loaded_params: Set[str]
+    ) -> None:
+        all_param_names = set(dict(model.named_parameters()).keys())
+        missing_params = all_param_names - loaded_params
+        if missing_params:
+            raise ValueError(
+                f"Missing parameters {tuple(missing_params)} in loaded state!"
+            )
+
+    def _run_post_load_hooks(
+        self, model: nn.Module, target_device: torch.device, model_config: ModelConfig
+    ) -> None:
+        self._apply_quantization_hooks(model, target_device)
+        self._move_buffers_to_device(model, target_device)
+        post_load_weights(model, model_config)
+
+    @staticmethod
+    def _apply_quantization_hooks(
+        model: nn.Module, target_device: torch.device
+    ) -> None:
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is None:
+                continue
+            with device_loading_context(module, target_device):
+                quant_method.process_weights_after_loading(module)
+
+    @staticmethod
+    def _move_buffers_to_device(model: nn.Module, target_device: torch.device) -> None:
+        for _, buffer in model.named_buffers(recurse=True):
+            if buffer.device.type != target_device.type:
+                buffer.data = buffer.data.to(target_device)
+
+
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
@@ -2096,4 +2314,7 @@ def get_model_loader(
     if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
         return RemoteInstanceModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.SERVERLESS_LLM:
+        return ServerlessLLMModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e798e01e522..5eda8584e18 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -80,6 +80,7 @@
     "layered",
     "remote",
    "remote_instance",
+    "serverless_llm",
 ]
 
 QUANTIZATION_CHOICES = [
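
Usage sketch (not part of the patch): a minimal save/reload flow exercising the new "serverless_llm" load format through the offline Engine API added here. It assumes sllm_store is installed, its store server is running over the directory that STORAGE_PATH points at, and that the HF config/tokenizer files are also present under the target directory; the model name and paths below are placeholders.

    import os
    import sglang as sgl

    # The loader resolves the sllm_store model id relative to STORAGE_PATH
    # (default "./models"), so the save path should live under it.
    os.environ["STORAGE_PATH"] = "/models"

    # 1) Export: load once with the default loader, then write each TP rank's
    #    shard to /models/my-model/rank_{tp_rank}.
    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", tp_size=1)
    engine.save_serverless_llm_state(path="/models/my-model")
    engine.shutdown()

    # 2) Reload: stream the saved shards back through the sllm_store server.
    #    (config.json / tokenizer files are assumed to be copied alongside.)
    engine = sgl.Engine(
        model_path="/models/my-model",
        load_format="serverless_llm",
        tp_size=1,
    )
    print(engine.generate("Hello", {"max_new_tokens": 8}))
    engine.shutdown()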