|
17 | 17 | transfer_kv_all_layer, |
18 | 18 | transfer_kv_all_layer_direct_lf_pf, |
19 | 19 | transfer_kv_all_layer_lf_pf, |
| 20 | + transfer_kv_all_layer_lf_ph, |
20 | 21 | transfer_kv_all_layer_mla, |
21 | 22 | transfer_kv_all_layer_mla_lf_pf, |
22 | 23 | transfer_kv_direct, |
|
25 | 26 | transfer_kv_per_layer_mla, |
26 | 27 | transfer_kv_per_layer_mla_pf_lf, |
27 | 28 | transfer_kv_per_layer_pf_lf, |
| 29 | + transfer_kv_per_layer_ph_lf, |
28 | 30 | ) |
29 | 31 | if _is_npu: |
30 | 32 | from sgl_kernel_npu.kvcacheio import TransferDirection, transfer_kv_dim_exchange |
@@ -238,6 +240,15 @@ def init_kv_buffer(self): |
238 | 240 | self.head_num, |
239 | 241 | self.head_dim, |
240 | 242 | ) |
| 243 | + elif self.layout == "page_head": |
| 244 | + dims = ( |
| 245 | + 2, |
| 246 | + self.page_num, |
| 247 | + self.head_num, |
| 248 | + self.page_size, |
| 249 | + self.layer_num, |
| 250 | + self.head_dim, |
| 251 | + ) |
241 | 252 | else: |
242 | 253 | raise ValueError(f"Unsupported layout: {self.layout}") |
243 | 254 | self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize |
@@ -292,6 +303,20 @@ def load_to_device_per_layer( |
292 | 303 | item_size=self.token_stride_size, |
293 | 304 | src_layout_dim=self.layout_dim, |
294 | 305 | ) |
| 306 | + elif self.layout == "page_head": |
| 307 | + transfer_kv_per_layer_ph_lf( |
| 308 | + src_k=self.k_buffer, |
| 309 | + dst_k=device_pool.k_buffer[layer_id], |
| 310 | + src_v=self.v_buffer, |
| 311 | + dst_v=device_pool.v_buffer[layer_id], |
| 312 | + src_indices=host_indices, |
| 313 | + dst_indices=device_indices, |
| 314 | + layer_id=layer_id, |
| 315 | + item_size=self.token_stride_size, |
| 316 | + src_layout_dim=self.layout_dim, |
| 317 | + page_size=self.page_size, |
| 318 | + head_num=self.head_num, |
| 319 | + ) |
295 | 320 | else: |
296 | 321 | raise ValueError(f"Unsupported layout: {self.layout}") |
297 | 322 | elif io_backend == "direct": |
@@ -366,6 +391,20 @@ def backup_from_device_all_layer( |
366 | 391 | dst_layout_dim=self.layout_dim, |
367 | 392 | num_layers=self.layer_num, |
368 | 393 | ) |
| 394 | + elif self.layout == "page_head": |
| 395 | + transfer_kv_all_layer_lf_ph( |
| 396 | + src_k_layers=device_pool.k_data_ptrs, |
| 397 | + dst_k=self.k_buffer, |
| 398 | + src_v_layers=device_pool.v_data_ptrs, |
| 399 | + dst_v=self.v_buffer, |
| 400 | + src_indices=device_indices, |
| 401 | + dst_indices=host_indices, |
| 402 | + item_size=self.token_stride_size, |
| 403 | + dst_layout_dim=self.layout_dim, |
| 404 | + num_layers=self.layer_num, |
| 405 | + page_size=self.page_size, |
| 406 | + head_num=self.head_num, |
| 407 | + ) |
369 | 408 | else: |
370 | 409 | raise ValueError(f"Unsupported layout: {self.layout}") |
371 | 410 | elif io_backend == "direct": |
@@ -409,7 +448,7 @@ def get_data_page(self, index, flat: bool = True) -> torch.Tensor: |
409 | 448 | data_page = self.kv_buffer[:, :, index : index + self.page_size, :, :] |
410 | 449 | elif self.layout == "page_first": |
411 | 450 | data_page = self.kv_buffer[:, index : index + self.page_size, :, :, :] |
412 | | - elif self.layout == "page_first_direct": |
| 451 | + elif self.layout in ["page_first_direct", "page_head"]: |
413 | 452 | real_index = index // self.page_size |
414 | 453 | data_page = self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] |
415 | 454 | else: |
@@ -450,6 +489,13 @@ def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None: |
450 | 489 | 2, 1, self.layer_num, self.page_size, self.head_num, self.head_dim |
451 | 490 | ) |
452 | 491 | ) |
| 492 | + elif self.layout == "page_head": |
| 493 | + real_index = index // self.page_size |
| 494 | + self.kv_buffer[:, real_index : real_index + 1, :, :, :, :] = ( |
| 495 | + data_page.reshape( |
| 496 | + 2, 1, self.head_num, self.page_size, self.layer_num, self.head_dim |
| 497 | + ) |
| 498 | + ) |
453 | 499 | else: |
454 | 500 | raise ValueError(f"Unsupported layout: {self.layout}") |
455 | 501 |
|
@@ -490,7 +536,7 @@ def get_page_buffer_meta(self, indices): |
490 | 536 | self.dtype.itemsize * self.page_size * self.head_num * self.head_dim |
491 | 537 | ) |
492 | 538 | element_size_list = [element_size] * len(ptr_list) |
493 | | - elif self.layout in ["page_first", "page_first_direct"]: |
| 539 | + elif self.layout in ["page_first", "page_first_direct", "page_head"]: |
494 | 540 | for index in range(0, len(indices), self.page_size): |
495 | 541 | k_ptr = ( |
496 | 542 | kv_buffer_data_ptr |
|
0 commit comments