Skip to content

Commit 1dcde53

Browse files
[HiCache] support memory_pool_host page head layout (#11644)
1 parent 15bc1f5 commit 1dcde53

File tree

3 files changed

+50
-2
lines changed

3 files changed

+50
-2
lines changed

python/sglang/srt/mem_cache/memory_pool_host.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
transfer_kv_all_layer,
1818
transfer_kv_all_layer_direct_lf_pf,
1919
transfer_kv_all_layer_lf_pf,
20+
transfer_kv_all_layer_lf_ph,
2021
transfer_kv_all_layer_mla,
2122
transfer_kv_all_layer_mla_lf_pf,
2223
transfer_kv_direct,
@@ -25,6 +26,7 @@
2526
transfer_kv_per_layer_mla,
2627
transfer_kv_per_layer_mla_pf_lf,
2728
transfer_kv_per_layer_pf_lf,
29+
transfer_kv_per_layer_ph_lf,
2830
)
2931
if _is_npu:
3032
from sgl_kernel_npu.kvcacheio import TransferDirection, transfer_kv_dim_exchange
@@ -238,6 +240,15 @@ def init_kv_buffer(self):
238240
self.head_num,
239241
self.head_dim,
240242
)
243+
elif self.layout == "page_head":
244+
dims = (
245+
2,
246+
self.page_num,
247+
self.head_num,
248+
self.page_size,
249+
self.layer_num,
250+
self.head_dim,
251+
)
241252
else:
242253
raise ValueError(f"Unsupported layout: {self.layout}")
243254
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
@@ -292,6 +303,20 @@ def load_to_device_per_layer(
292303
item_size=self.token_stride_size,
293304
src_layout_dim=self.layout_dim,
294305
)
306+
elif self.layout == "page_head":
307+
transfer_kv_per_layer_ph_lf(
308+
src_k=self.k_buffer,
309+
dst_k=device_pool.k_buffer[layer_id],
310+
src_v=self.v_buffer,
311+
dst_v=device_pool.v_buffer[layer_id],
312+
src_indices=host_indices,
313+
dst_indices=device_indices,
314+
layer_id=layer_id,
315+
item_size=self.token_stride_size,
316+
src_layout_dim=self.layout_dim,
317+
page_size=self.page_size,
318+
head_num=self.head_num,
319+
)
295320
else:
296321
raise ValueError(f"Unsupported layout: {self.layout}")
297322
elif io_backend == "direct":
@@ -366,6 +391,20 @@ def backup_from_device_all_layer(
366391
dst_layout_dim=self.layout_dim,
367392
num_layers=self.layer_num,
368393
)
394+
elif self.layout == "page_head":
395+
transfer_kv_all_layer_lf_ph(
396+
src_k_layers=device_pool.k_data_ptrs,
397+
dst_k=self.k_buffer,
398+
src_v_layers=device_pool.v_data_ptrs,
399+
dst_v=self.v_buffer,
400+
src_indices=device_indices,
401+
dst_indices=host_indices,
402+
item_size=self.token_stride_size,
403+
dst_layout_dim=self.layout_dim,
404+
num_layers=self.layer_num,
405+
page_size=self.page_size,
406+
head_num=self.head_num,
407+
)
369408
else:
370409
raise ValueError(f"Unsupported layout: {self.layout}")
371410
elif io_backend == "direct":
@@ -409,7 +448,7 @@ def get_data_page(self, index, flat: bool = True) -> torch.Tensor:
409448
data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :]
410449
elif self.layout == "page_first":
411450
data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :]
412-
elif self.layout == "page_first_direct":
451+
elif self.layout in ["page_first_direct", "page_head"]:
413452
real_index = index // self.page_size
414453
data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :]
415454
else:
@@ -450,6 +489,13 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
450489
2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim
451490
)
452491
)
492+
elif self.layout == "page_head":
493+
real_index = index // self.page_size
494+
self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = (
495+
data_page.reshape(
496+
2, 1, self.head_num, self.page_size, self.layer_num, self.head_dim
497+
)
498+
)
453499
else:
454500
raise ValueError(f"Unsupported layout: {self.layout}")
455501

@@ -490,7 +536,7 @@ def get_page_buffer_meta(self, indices):
490536
self.dtype.itemsize * self.page_size * self.head_num * self.head_dim
491537
)
492538
element_size_list = [element_size] * len(ptr_list)
493-
elif self.layout in ["page_first", "page_first_direct"]:
539+
elif self.layout in ["page_first", "page_first_direct", "page_head"]:
494540
for index in range(0, len(indices), self.page_size):
495541
k_ptr = (
496542
kv_buffer_data_ptr

python/sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def register_mem_pool_host(self, mem_pool_host: HostKVCache):
265265
assert self.mem_pool_host.layout in [
266266
"page_first",
267267
"page_first_direct",
268+
"page_head",
268269
], "mooncake store storage backend only supports page_first, page_first_direct, or page_head layouts"
269270
buffer = self.mem_pool_host.kv_buffer
270271
try:

python/sglang/srt/server_args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3074,6 +3074,7 @@ def add_cli_args(parser: argparse.ArgumentParser):
30743074
"page_first",
30753075
"page_first_direct",
30763076
"page_first_kv_split",
3077+
"page_head",
30773078
],
30783079
default=ServerArgs.hicache_mem_layout,
30793080
help="The layout of host memory pool for hierarchical cache.",

0 commit comments

Comments
 (0)