1 change: 1 addition & 0 deletions python/sglang/srt/configs/load_config.py
@@ -28,6 +28,7 @@ class LoadFormat(str, enum.Enum):
REMOTE_INSTANCE = "remote_instance"
RDMA = "rdma"
LOCAL_CACHED = "local_cached"
SERVERLESS_LLM = "serverless_llm"


@dataclass
3 changes: 3 additions & 0 deletions python/sglang/srt/entrypoints/engine.py
@@ -608,6 +608,9 @@ def save_remote_model(self, **kwargs):
def save_sharded_model(self, **kwargs):
self.collective_rpc("save_sharded_model", **kwargs)

def save_serverless_llm_state(self, **kwargs):
self.collective_rpc("save_serverless_llm_state", **kwargs)

def score(
self,
query: Optional[Union[str, List[int]]] = None,
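For reference, a minimal save-side sketch (hedged: the model path, output path, and Engine keyword arguments are illustrative, and collective_rpc is assumed to forward the keyword arguments to each scheduler as the params dict handled in the next hunk):

import sglang as sgl

# Any served model and writable directory will do; these values are examples.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", tp_size=2)

# Broadcast to every TP rank; each rank writes its own shard under
# <path>/rank_{tp_rank} (see ServerlessLLMModelLoader.save_model below).
engine.save_serverless_llm_state(path="./models/llama-3.1-8b")

engine.shutdown()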
@@ -184,6 +184,9 @@ def save_sharded_model(self: Scheduler, params):
max_size=params["max_size"],
)

def save_serverless_llm_state(self: Scheduler, params):
self.tp_worker.model_runner.save_serverless_llm_state(path=params["path"])


def _export_static_state(model):
return dict(
6 changes: 6 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
@@ -2372,6 +2372,12 @@ def save_sharded_model(
)
ShardedStateLoader.save_model(self.model, path, pattern, max_size)

def save_serverless_llm_state(self, path: str):
from sglang.srt.model_loader.loader import ServerlessLLMModelLoader

logger.info(f"Save ServerlessLLM model state to {path}")
ServerlessLLMModelLoader.save_model(self.model, path)

def update_weights_from_ipc(self, recv_req):
"""Update weights from IPC for checkpoint-engine integration."""
try:
205 changes: 205 additions & 0 deletions python/sglang/srt/model_loader/loader.py
@@ -24,6 +24,7 @@
Iterable,
List,
Optional,
Set,
Tuple,
cast,
)
@@ -284,6 +285,207 @@ def load_model(
raise NotImplementedError


class ServerlessLLMModelLoader(BaseModelLoader):
"""Model loader that loads weights from ServerlessLLM checkpoint store.

This mirrors the vLLM sllm_loader behavior: each TP rank reads its own
shard from `model_path/rank_{tp_rank}` via `sllm_store.torch.load_dict`.
"""

def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for load format {load_config.load_format}"
)

@staticmethod
def _filter_subtensors(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Filter out view/sub-tensors that share storage with other tensors."""
same_storage_groups: Dict[Any, List[Tuple[str, torch.Tensor]]] = (
collections.defaultdict(list)
)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))

def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

result: Dict[str, torch.Tensor] = {}
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break
if k2 < k:
break
else:
result[k] = t
return result
Comment on lines +307 to +336 (Contributor review, severity: medium):

The _filter_subtensors method is duplicated here from ShardedStateLoader. To improve maintainability and reduce code duplication, consider extracting this common utility into a shared helper function or a base class method if applicable. [1]

[1] Avoid code duplication to improve maintainability and reduce the risk of inconsistencies when changes are needed.


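One way to address this, as a minimal sketch assuming both loaders stay in loader.py: hoist the helper to module scope and let each class delegate to it. The names below are suggestions, not existing API.

def filter_subtensors(tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Single home for the logic currently duplicated in both loaders."""
    # body identical to the _filter_subtensors implementation above
    ...


class ShardedStateLoader(BaseModelLoader):
    # existing methods unchanged; the staticmethod now just delegates
    _filter_subtensors = staticmethod(filter_subtensors)


class ServerlessLLMModelLoader(BaseModelLoader):
    # existing methods unchanged; the staticmethod now just delegates
    _filter_subtensors = staticmethod(filter_subtensors)
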
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
from sllm_store.torch import load_dict

from sglang.srt.distributed import get_tensor_model_parallel_rank

tp_rank = get_tensor_model_parallel_rank()
local_model_path = os.path.join(model_config.model_path, f"rank_{tp_rank}")
model_id = self._get_model_id(local_model_path, self._get_storage_path())

target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
model = self._initialize_model_on_cpu(model_config)

device_map = self._build_device_map(target_device)
sllm_state_dict = load_dict(model_id, device_map)

full_state_dict = model.state_dict()
storage_to_keys = self._build_storage_to_keys(full_state_dict)
loaded_params = self._assign_parameters(
model, sllm_state_dict, full_state_dict, storage_to_keys
)
self._assert_all_parameters_loaded(model, loaded_params)
self._run_post_load_hooks(model, target_device, model_config)

return model.eval()

@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
) -> None:
"""Save current TP shard tensors to ServerlessLLM store format."""
from sllm_store.torch import save_dict

from sglang.srt.distributed import get_tensor_model_parallel_rank

rank = get_tensor_model_parallel_rank()
state_dict = ServerlessLLMModelLoader._filter_subtensors(model.state_dict())

# Move to CPU for saving (store accepts CPU tensors)
cpu_state: Dict[str, torch.Tensor] = {}
for key, tensor in state_dict.items():
cpu_state[key] = tensor.detach().to("cpu").contiguous()

save_path = os.path.join(path, f"rank_{rank}")
os.makedirs(save_path, exist_ok=True)
save_dict(cpu_state, save_path)

@staticmethod
def _get_storage_path() -> str:
return os.getenv("STORAGE_PATH", "./models")

@staticmethod
def _get_model_id(local_model_path: str, storage_path: str) -> str:
normalized_path = os.path.normpath(local_model_path)
normalized_storage = os.path.normpath(storage_path)
if normalized_path.startswith(normalized_storage):
return normalized_path[len(normalized_storage) :].lstrip(os.sep)
return normalized_path

def _initialize_model_on_cpu(self, model_config: ModelConfig) -> nn.Module:
with torch.device("cpu"):
model = _initialize_model(model_config, self.load_config)
return model.eval()

@staticmethod
def _build_device_map(target_device: torch.device) -> Dict[str, int]:
if target_device.type == "cuda":
return {"": torch.cuda.current_device()}
return {"": 0}

@staticmethod
def _build_storage_to_keys(
state_dict: Dict[str, torch.Tensor],
) -> Dict[Tuple[torch.device, int], List[str]]:
storage_to_keys: Dict[Tuple[torch.device, int], List[str]] = (
collections.defaultdict(list)
)
for key, tensor in state_dict.items():
if tensor.numel() > 0:
storage_ptr = (tensor.device, tensor.untyped_storage().data_ptr())
storage_to_keys[storage_ptr].append(key)
return storage_to_keys

@staticmethod
def _assign_parameters(
model: nn.Module,
sllm_state_dict: Dict[str, torch.Tensor],
full_state_dict: Dict[str, torch.Tensor],
storage_to_keys: Dict[Tuple[torch.device, int], List[str]],
) -> Set[str]:
loaded_params: Set[str] = set()
for key, param in model.named_parameters(recurse=True):
tensor = sllm_state_dict.get(key)
if tensor is not None:
param.data = tensor
loaded_params.add(key)
continue

orig_tensor = full_state_dict.get(key)
if orig_tensor is None or orig_tensor.numel() == 0:
continue

storage_ptr = (orig_tensor.device, orig_tensor.untyped_storage().data_ptr())
for saved_key in storage_to_keys.get(storage_ptr, []):
saved_tensor = sllm_state_dict.get(saved_key)
if saved_tensor is None:
continue
param.data = saved_tensor
loaded_params.add(key)
break

return loaded_params

@staticmethod
def _assert_all_parameters_loaded(
model: nn.Module, loaded_params: Set[str]
) -> None:
all_param_names = set(dict(model.named_parameters()).keys())
missing_params = all_param_names - loaded_params
if missing_params:
raise ValueError(
f"Missing parameters {tuple(missing_params)} in loaded state!"
)

def _run_post_load_hooks(
self, model: nn.Module, target_device: torch.device, model_config: ModelConfig
) -> None:
self._apply_quantization_hooks(model, target_device)
self._move_buffers_to_device(model, target_device)
post_load_weights(model, model_config)

@staticmethod
def _apply_quantization_hooks(
model: nn.Module, target_device: torch.device
) -> None:
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is None:
continue
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)

@staticmethod
def _move_buffers_to_device(model: nn.Module, target_device: torch.device) -> None:
for _, buffer in model.named_buffers(recurse=True):
if buffer.device.type != target_device.type:
buffer.data = buffer.data.to(target_device)


class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""

@@ -2096,4 +2298,7 @@ def get_model_loader(
if load_config.load_format == LoadFormat.REMOTE_INSTANCE:
return RemoteInstanceModelLoader(load_config)

if load_config.load_format == LoadFormat.SERVERLESS_LLM:
return ServerlessLLMModelLoader(load_config)

return DefaultModelLoader(load_config)
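
A small sketch of the new dispatch path, assuming get_model_loader takes only the LoadConfig as shown above; the STORAGE_PATH value is an example:

import os

from sglang.srt.configs.load_config import LoadConfig, LoadFormat
from sglang.srt.model_loader.loader import ServerlessLLMModelLoader, get_model_loader

# The loader resolves each rank's model_id relative to STORAGE_PATH (default "./models").
os.environ["STORAGE_PATH"] = "./models"

load_config = LoadConfig(load_format=LoadFormat.SERVERLESS_LLM)
loader = get_model_loader(load_config)
assert isinstance(loader, ServerlessLLMModelLoader)
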
1 change: 1 addition & 0 deletions python/sglang/srt/server_args.py
@@ -80,6 +80,7 @@
"layered",
"remote",
"remote_instance",
"serverless_llm",
]

QUANTIZATION_CHOICES = [
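On the load side, mirroring the save sketch above: restart the engine pointing at the saved directory with the new load format. This is a hedged sketch; the paths are illustrative, a running sllm_store server is assumed, and the directory is assumed to also hold the usual HF config and tokenizer files so the model config can be resolved.

import sglang as sgl

# Each TP rank reads ./models/llama-3.1-8b/rank_{tp_rank} via sllm_store.torch.load_dict.
engine = sgl.Engine(
    model_path="./models/llama-3.1-8b",
    load_format="serverless_llm",
    tp_size=2,
)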