diff --git a/ucm/integration/vllm/ucm_connector.py b/ucm/integration/vllm/ucm_connector.py index 66216a255..1ac2a9d6a 100644 --- a/ucm/integration/vllm/ucm_connector.py +++ b/ucm/integration/vllm/ucm_connector.py @@ -4,7 +4,7 @@ import pickle import time from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional, Tuple import torch from vllm.config import VllmConfig @@ -20,8 +20,8 @@ from ucm.logger import init_logger from ucm.shared.metrics import ucmmonitor from ucm.shared.metrics.observability import UCMStatsLogger -from ucm.store.factory import UcmConnectorFactory -from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.store.factory_v1 import UcmConnectorFactoryV1 +from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1 from ucm.utils import Config if TYPE_CHECKING: @@ -35,7 +35,7 @@ @dataclass class RequestMeta: - ucm_block_ids: list[str] = field(default_factory=list) + ucm_block_ids: list[bytes] = field(default_factory=list) hbm_hit_block_num: int = 0 # local_computed_block + external_computed_block total_hit_block_num: int = 0 @@ -47,9 +47,9 @@ class RequestMeta: @dataclass class RequestDispatchMeta: load_block_ids: tuple[ - list[str], list[int] + list[bytes], list[int] ] # [0] mean ucm_block_ids, [1] means vllm_block_ids - dump_block_ids: tuple[list[str], list[int]] + dump_block_ids: tuple[list[bytes], list[int]] @dataclass @@ -69,14 +69,14 @@ def __init__(self, vllm_config, rank_id): if RequestHasher._SEED_HASH is None: RequestHasher._SEED_HASH = self("UCM_HASH_SEED") - def __call__(self, input_data) -> int: - if isinstance(input_data, str): - input_bytes = input_data.encode("utf-8") + def __call__(self, input_data) -> bytes: + if isinstance(input_data, bytes): + input_bytes = input_data else: input_bytes = pickle.dumps(input_data, protocol=pickle.HIGHEST_PROTOCOL) h = hashlib.md5(self.meta_bytes + input_bytes) - return int.from_bytes(h.digest(), 
byteorder="big") + return h.digest() class UCMDirectConnector(KVConnectorBase_V1): @@ -95,6 +95,10 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.block_size = self._vllm_config.cache_config.block_size self.is_mla = self._vllm_config.model_config.is_deepseek_mla self.is_dsa = False + self.num_layers = self._vllm_config.model_config.get_num_layers( + self._vllm_config.parallel_config + ) + self.tp_size = self._vllm_config.parallel_config.tensor_parallel_size self.kv_cache_dtype: torch.dtype = None if current_platform.is_cuda_alike(): @@ -110,21 +114,18 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): if self.local_rank >= 0: self.device = torch_dev.device(f"{dev_name}:{self.local_rank}") - self._layer_offset_cache = {} - - self.store: UcmKVStoreBase - if role == KVConnectorRole.SCHEDULER: - self.request_hasher = RequestHasher(vllm_config, 0) - else: - self.request_hasher = RequestHasher(vllm_config, self.global_rank) + self.k_store: UcmKVStoreBaseV1 + self.v_store: Optional[UcmKVStoreBaseV1] = None # save block info, avoid hash request twice, and track them until request finished self.requests_meta: dict[str, RequestMeta] = {} ucm_config = Config(vllm_config.kv_transfer_config) self.launch_config = ucm_config.get_config() - + logger.info(f"self.launch_config: {self.launch_config}") + self.connector_configs = self.launch_config.get("ucm_connectors", []) + assert len(self.connector_configs) > 0, "no storage connector name in config." 
self.load_only_first_rank: bool = ( self.launch_config.get("load_only_first_rank", self.is_mla) and self.is_mla ) @@ -134,42 +135,28 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): self.broadcast_fn = self.group_coordinator.broadcast self.broadcast_stream = torch.cuda.Stream() - logger.info(f"self.launch_config: {self.launch_config}") - connector_configs = self.launch_config.get("ucm_connectors", []) - assert len(connector_configs) > 0, "no storage connector name in config." - - name = connector_configs[0].get("ucm_connector_name") - config = connector_configs[0].get("ucm_connector_config") or {} - config["device"] = self.local_rank - config["role"] = "scheduler" if role == KVConnectorRole.SCHEDULER else "worker" - element_size = vllm_config.model_config.dtype.itemsize - single_head_dim = vllm_config.model_config.get_head_size() - num_head_per_tp = vllm_config.model_config.get_num_kv_heads( - vllm_config.parallel_config - ) - total_tp_size = vllm_config.parallel_config.tensor_parallel_size - num_layers = vllm_config.model_config.get_num_layers( - vllm_config.parallel_config - ) - block_size_per_layer = self.block_size * element_size * single_head_dim - config["kv_block_size"] = ( - block_size_per_layer - * num_layers - * (1 if self.is_mla else num_head_per_tp * 2) - ) - config["io_size"] = block_size_per_layer * ( - 1 if self.is_mla else num_head_per_tp - ) - self.store = UcmConnectorFactory.create_connector(name, config) - self.block_data_size = config["kv_block_size"] - - logger.info("init UCConnectorImpl, connector: %s", name) + name = self.connector_configs[0].get("ucm_connector_name") + config = self.connector_configs[0].get("ucm_connector_config") or {} + storage_backends = [ + path for path in config["storage_backends"].split(":") if path + ] + self.k_storage_backends = [os.path.join(p, "k") for p in storage_backends] + self.v_storage_backends = [os.path.join(p, "v") for p in storage_backends] + 
os.makedirs(self.k_storage_backends[0], exist_ok=True) + os.makedirs(self.v_storage_backends[0], exist_ok=True) logger.info( - "single file size = %d MB, io_size = %d KB,", - config["kv_block_size"] / 1024 / 1024, - config["io_size"] / 1024, + f"Created subdirectories: {self.k_storage_backends}, {self.v_storage_backends}" ) + if role == KVConnectorRole.SCHEDULER: + self.request_hasher = RequestHasher(vllm_config, 0) + # init scheduler-size connector + config["storage_backends"] = ":".join(self.k_storage_backends) + config["role"] = "scheduler" + self.k_store = UcmConnectorFactoryV1.create_connector(name, config) + else: + self.request_hasher = RequestHasher(vllm_config, self.global_rank) + self.metrics_config = self.launch_config.get("metrics_config_path", "") if self.metrics_config: self.stats_logger = UCMStatsLogger( @@ -188,7 +175,7 @@ def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole): # invlalid block ids due to load errors self._invalid_block_ids: set[int] = set() - def generate_hash(self, block_size: int, request: "Request") -> list[str]: + def generate_hash(self, block_size: int, request: "Request") -> list[bytes]: token_ids = request.all_token_ids ret = [] @@ -205,10 +192,78 @@ def generate_hash(self, block_size: int, request: "Request") -> list[str]: (parent_block_hash_value, block_token_ids_tuple) ) parent_block_hash_value = hash_value - ret.append(str(hash_value)) + ret.append(hash_value) return ret + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + self.kv_caches = kv_caches + sample_kv_layer = next(iter(self.kv_caches.values())) + if self.kv_cache_dtype is None: + self.kv_cache_dtype = sample_kv_layer[0].dtype + if isinstance(sample_kv_layer, torch.Tensor): + logger.info(f"kv cache shape {sample_kv_layer.shape}") + elif isinstance(sample_kv_layer, Tuple): + # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to Tuple + # [(num_blocks, block_size, num_kv_heads, nope_dim/rope_dim)] + # Currently, 
we treat it as GQA, and use is_dsa to mark it + for i, tensor in enumerate(sample_kv_layer): + logger.info(f"kv cache shape {i}: {tensor.shape}") + if self.is_mla: + self.is_mla = False + self.is_dsa = True + logger.info(f"use mla: {self.is_mla}, use dsa: {self.is_dsa}") + + # init work-side connector + # When handling the GQA case, we will separately dump the k_cache and v_cache. + name = self.connector_configs[0].get("ucm_connector_name") + config = self.connector_configs[0].get("ucm_connector_config") or {} + config["device"] = self.local_rank + config["role"] = "worker" + config["local_rank_size"] = self.tp_size if self.is_mla or self.is_dsa else 1 + if len(sample_kv_layer) == 2: + k_io_size = ( + sample_kv_layer[0][0].numel() * sample_kv_layer[0][0].element_size() + ) + config["io_size"] = k_io_size + config["kv_block_size"] = k_io_size * self.num_layers + config["storage_backends"] = ":".join(self.k_storage_backends) + self.k_store = UcmConnectorFactoryV1.create_connector(name, config) + logger.info("init UCConnectorImpl, k_connector: %s", name) + logger.info( + "single file size = %.3f MB, io_size = %d KB,", + config["kv_block_size"] / 1024 / 1024, + config["io_size"] / 1024, + ) + + v_io_size = ( + sample_kv_layer[1][0].numel() * sample_kv_layer[1][0].element_size() + ) + config["io_size"] = v_io_size + config["kv_block_size"] = v_io_size * self.num_layers + config["storage_backends"] = ":".join(self.v_storage_backends) + self.v_store = UcmConnectorFactoryV1.create_connector(name, config) + logger.info("init UCConnectorImpl, v_connector: %s", name) + logger.info( + "single file size = %.3f MB, io_size = %d KB,", + config["kv_block_size"] / 1024 / 1024, + config["io_size"] / 1024, + ) + self.block_data_size = (k_io_size + v_io_size) * self.num_layers + else: + k_io_size = sample_kv_layer[0].numel() * sample_kv_layer[0].element_size() + config["io_size"] = k_io_size + config["kv_block_size"] = k_io_size * self.num_layers + config["storage_backends"] = 
":".join(self.k_storage_backends) + self.k_store = UcmConnectorFactoryV1.create_connector(name, config) + logger.info("init UCConnectorImpl, k_connector: %s", name) + logger.info( + "single file size = %.3f MB, io_size = %d KB,", + config["kv_block_size"] / 1024 / 1024, + config["io_size"] / 1024, + ) + self.block_data_size = k_io_size * self.num_layers + def get_num_new_matched_tokens( self, request: "Request", @@ -223,7 +278,7 @@ def get_num_new_matched_tokens( if not external_block_ids: return 0, False - lookup_results = self.store.lookup(external_block_ids) + lookup_results = self.k_store.lookup(external_block_ids) external_hit_blocks = 0 for i, hit in enumerate(lookup_results): if not hit: @@ -361,30 +416,6 @@ def build_connector_meta( return UCMConnectorMetadata(requests_dispatch_meta) - def _init_kv_caches_from_forward_context(self, forward_context: "ForwardContext"): - if len(self.kv_caches) > 0: - return - for layer_name in forward_context.no_compile_layers: - attn_layer = forward_context.no_compile_layers[layer_name] - if not hasattr(attn_layer, "kv_cache"): - continue - - if layer_name not in self.kv_caches: - self.kv_caches[layer_name] = attn_layer.kv_cache[ - forward_context.virtual_engine - ] - # Since vllm_ascend >= 0.10.0, the MLA model's tensor shape has changed to - # (2, num_blocks, block_size, num_kv_heads, nope_dim/rope_dim). - # Currently, we treat it as GQA, and use is_dsa to mark it, - # which works but leads to space inefficiency. - # TODO: Optimize this to avoid unnecessary space usage. 
- sample_kv_layer = next(iter(self.kv_caches.values())) - if self.is_mla and len(sample_kv_layer) == 2: - self.is_mla = False - self.is_dsa = True - if self.kv_cache_dtype is None: - self.kv_cache_dtype = sample_kv_layer[0].dtype - @staticmethod def _extract_layer_index(layer_name: str) -> Optional[int]: """ @@ -395,70 +426,36 @@ def _extract_layer_index(layer_name: str) -> Optional[int]: return int(chunk) return None - def _precompute_layer_offsets(self): - if not self.kv_caches: - return - - sample_kv_layer = next(iter(self.kv_caches.values())) - elem_size = sample_kv_layer[0].element_size() - block_data_size = ( - sample_kv_layer[0].numel() if self.is_mla else sample_kv_layer[0][0].numel() - ) * elem_size - layer_data_size = block_data_size if self.is_mla else block_data_size * 2 - - # precompute all layers offset - for layer_name, _ in self.kv_caches.items(): - layer_id = self._extract_layer_index(layer_name) - assert layer_id is not None - k_offset = layer_data_size * layer_id - v_offset = k_offset + block_data_size if not self.is_mla else 0 - self._layer_offset_cache[layer_name] = (k_offset, v_offset) - - def _get_tensor_and_offset( - self, vllm_block_ids: list[int], kv_layer: torch.Tensor, layer_name: str - ) -> tuple[list[torch.Tensor], list[int]]: + def _get_tensors( + self, vllm_block_id: int + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """ GQA/MHA: one layer shape is (2, num_blocks, block_size, num_kv_heads, head_size) MLA: one layer shape is (num_blocks, block_size, head_size) """ - k_tensors, k_offsets = [], [] - v_tensors, v_offsets = [], [] - k_offset, v_offset = self._layer_offset_cache[layer_name] - - for vllm_block_id in vllm_block_ids: + k_tensors, v_tensors = [], [] + for _, kv_layer in self.kv_caches.items(): k_tensors.append( kv_layer[vllm_block_id] if self.is_mla else kv_layer[0][vllm_block_id] ) - k_offsets.append(k_offset) if not self.is_mla: v_tensors.append(kv_layer[1][vllm_block_id]) - v_offsets.append(v_offset) - return 
k_tensors + v_tensors, k_offsets + v_offsets - - def _generate_task(self, vllm_block_ids: List[int], ucm_block_ids: List[str]): - if not self._layer_offset_cache: - self._precompute_layer_offsets() - - num_layers = len(self.kv_caches) - num_blocks_per_layer = len(vllm_block_ids) - num_tensors_per_layer = num_blocks_per_layer * (1 if self.is_mla else 2) - dst_tensor_addr = [None] * (num_layers * num_tensors_per_layer) - ucm_offsets = [0] * (num_layers * num_tensors_per_layer) - - idx = 0 - for layer_name, one_layer_kv_cache in self.kv_caches.items(): - tensors, offsets = self._get_tensor_and_offset( - vllm_block_ids, one_layer_kv_cache, layer_name - ) - dst_tensor_addr[idx : idx + len(tensors)] = tensors - ucm_offsets[idx : idx + len(offsets)] = offsets - idx += len(tensors) - - repeat_times = len(self.kv_caches) * (1 if self.is_mla else 2) - ucm_total_block_ids = ucm_block_ids * repeat_times - - assert len(ucm_total_block_ids) == len(ucm_offsets) == len(dst_tensor_addr) - return ucm_total_block_ids, ucm_offsets, dst_tensor_addr + return k_tensors, v_tensors + + def _generate_task( + self, vllm_block_ids: List[int], ucm_block_ids: List[bytes] + ) -> Tuple[ + List[bytes], List[int], List[List[torch.Tensor]], List[List[torch.Tensor]] + ]: + block_ids, shard_indexs, total_k_tensors, total_v_tensors = [], [], [], [] + for i, vllm_block_id in enumerate(vllm_block_ids): + k_tensors, v_tensors = self._get_tensors(vllm_block_id) + block_ids.append(ucm_block_ids[i]) + total_k_tensors.append(k_tensors) + total_v_tensors.append(v_tensors) + shard_indexs.append(0) + + return block_ids, shard_indexs, total_k_tensors, total_v_tensors def _broadcast(self, dst_tensor_addr: list[torch.Tensor]): rec_tensor: torch.Tensor = None @@ -483,9 +480,7 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: metadata = self._get_connector_metadata() assert isinstance(metadata, UCMConnectorMetadata) - self._init_kv_caches_from_forward_context(forward_context) - - 
request_to_task: dict[str, Optional[List[Task]]] = {} req_broadcast_addr = {} is_load = False num_loaded_block = 0 @@ -501,26 +496,34 @@ def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: ucm_block_ids, vllm_block_ids = request.load_block_ids if self.global_rank != 0 and not self.is_mla and not self.is_dsa: for i, ucm_block_id in enumerate(ucm_block_ids): - ucm_block_ids[i] = str(self.request_hasher(ucm_block_id)) - ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task( + ucm_block_ids[i] = self.request_hasher(ucm_block_id) + block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task( vllm_block_ids, ucm_block_ids ) if self.global_rank == 0 or not self.load_only_first_rank: - request_to_task[request_id] = self.store.load( - ucm_total_block_ids, ucm_offsets, dst_tensor_addr - ) + k_task = self.k_store.load(block_ids, shard_indexs, k_tensors) + request_to_task[request_id] = [k_task] + if v_tensors and self.v_store: + v_task = self.v_store.load(block_ids, shard_indexs, v_tensors) + request_to_task[request_id].append(v_task) else: request_to_task[request_id] = None - req_broadcast_addr[request_id] = dst_tensor_addr + req_broadcast_addr[request_id] = [t for row in k_tensors for t in row] + [ + t for row in v_tensors for t in row + ] - for request_id, task in request_to_task.items(): + for request_id, tasks in request_to_task.items(): # TODO error handling if self.global_rank == 0 or not self.load_only_first_rank: - if self.store.wait(task) != 0: + try: + self.k_store.wait(tasks[0]) + if len(tasks) > 1 and self.v_store: + self.v_store.wait(tasks[1]) + except RuntimeError as e: + logger.error(f"request {request_id} load kv cache failed: {e}") self._invalid_block_ids.update( metadata.request_meta[request_id].load_block_ids[1] ) - logger.error(f"request {request_id} load kv cache failed.") if self.load_only_first_rank: self._broadcast(req_broadcast_addr[request_id]) 
load_end_time = time.perf_counter() * 1000 @@ -567,8 +570,7 @@ def wait_for_save(self) -> None: metadata = self._get_connector_metadata() assert isinstance(metadata, UCMConnectorMetadata) - request_to_task: dict[str, Task] = {} - request_to_blocks: dict[str, list[str]] = {} + request_to_task: dict[str, List[Task]] = {} is_save = False num_saved_block = 0 num_saved_request = 0 @@ -583,36 +585,23 @@ def wait_for_save(self) -> None: ucm_block_ids, vllm_block_ids = request.dump_block_ids if self.global_rank != 0: for i, ucm_block_id in enumerate(ucm_block_ids): - ucm_block_ids[i] = str(self.request_hasher(ucm_block_id)) - rets = self.store.create(ucm_block_ids) - end = 0 - for i, ret in enumerate(rets): - if ret != 0: - logger.error( - f"create blocks for {request_id} failed, block index: {i}, ret code: {ret}" - ) - break - end += 1 - - if end == 0: - continue - ucm_block_ids = ucm_block_ids[:end] - vllm_block_ids = vllm_block_ids[:end] - ucm_total_block_ids, ucm_offsets, dst_tensor_addr = self._generate_task( + ucm_block_ids[i] = self.request_hasher(ucm_block_id) + block_ids, shard_indexs, k_tensors, v_tensors = self._generate_task( vllm_block_ids, ucm_block_ids ) - request_to_task[request_id] = self.store.dump( - ucm_total_block_ids, ucm_offsets, dst_tensor_addr - ) - request_to_blocks[request_id] = ucm_block_ids - - for request_id, task in request_to_task.items(): - ucm_block_ids = request_to_blocks[request_id] - if self.store.wait(task) == 0: - self.store.commit(ucm_block_ids, True) - else: - logger.error(f"request {request_id} dump kv cache failed.") - self.store.commit(ucm_block_ids, False) + k_task = self.k_store.dump(block_ids, shard_indexs, k_tensors) + request_to_task[request_id] = [k_task] + if v_tensors and self.v_store: + v_task = self.v_store.dump(block_ids, shard_indexs, v_tensors) + request_to_task[request_id].append(v_task) + + for request_id, tasks in request_to_task.items(): + try: + self.k_store.wait(tasks[0]) + if len(tasks) > 1 and self.v_store: + 
self.v_store.wait(tasks[1]) + except RuntimeError as e: + logger.error(f"request {request_id} dump kv cache failed: {e}") save_end_time = time.perf_counter() * 1000 save_speed = ( num_saved_block @@ -793,6 +782,16 @@ def update_state_after_alloc( """ self.connector.update_state_after_alloc(request, blocks, num_external_tokens) + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: kv_caches: + dictionary of layer names, kv cache + """ + self.connector.register_kv_caches(kv_caches) + def build_connector_meta( self, scheduler_output: SchedulerOutput ) -> KVConnectorMetadata: diff --git a/ucm/store/factory.py b/ucm/store/factory.py index 8b893cda3..ac4c0569a 100644 --- a/ucm/store/factory.py +++ b/ucm/store/factory.py @@ -63,9 +63,6 @@ def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBase: UcmConnectorFactory.register_connector( "UcmNfsStore", "ucm.store.nfsstore.nfsstore_connector", "UcmNfsStore" ) -UcmConnectorFactory.register_connector( - "UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore" -) UcmConnectorFactory.register_connector( "UcmMooncakeStore", "ucm.store.mooncakestore.mooncake_connector", diff --git a/ucm/store/factory_v1.py b/ucm/store/factory_v1.py new file mode 100644 index 000000000..2a502e46e --- /dev/null +++ b/ucm/store/factory_v1.py @@ -0,0 +1,62 @@ +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# + +import importlib +from typing import Callable + +from ucm.logger import init_logger +from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1 + +logger = init_logger(__name__) + + +class UcmConnectorFactoryV1: + _registry: dict[str, Callable[[], type[UcmKVStoreBaseV1]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[UcmKVStoreBaseV1]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, connector_name: str, config: dict) -> UcmKVStoreBaseV1: + if connector_name in cls._registry: + connector_cls = cls._registry[connector_name]() + else: + raise ValueError(f"Unsupported connector type: {connector_name}") + assert issubclass(connector_cls, UcmKVStoreBaseV1) + logger.info("Creating connector with name: %s", connector_name) + return connector_cls(config) + + +UcmConnectorFactoryV1.register_connector( + "UcmPcStore", "ucm.store.pcstore.pcstore_connector", "UcmPcStore" +) diff --git a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc index b69ebfe29..993b60299 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_queue.cc +++ b/ucm/store/pcstore/cc/domain/trans/trans_queue.cc @@ -105,9 +105,14 @@ Status TransQueue::Setup(const int32_t deviceId, const size_t streamNumber, cons UC_ERROR("Failed({}) to make host buffer({},{}).", ts.ToString(), blockSize, bufferNumber); return Status::Error(); } - auto success = - this->devPool_.SetWorkerFn([this](auto t, auto) { this->DeviceWorker(std::move(t)); }) - .Run(); + auto success = this->devPool_ + .SetWorkerInitFn([deviceId](auto&) { + Trans::Device device; + auto ts = device.Setup(deviceId); + return ts.Success(); + }) + 
.SetWorkerFn([this](auto t, auto) { this->DeviceWorker(std::move(t)); }) + .Run(); if (!success) { return Status::Error(); } success = this->filePool_.SetWorkerFn([this](auto t, auto) { this->FileWorker(std::move(t)); }) .SetNWorker(streamNumber) diff --git a/ucm/store/pcstore/cc/domain/trans/trans_task.h b/ucm/store/pcstore/cc/domain/trans/trans_task.h index 8fcb48fba..7078f3d00 100644 --- a/ucm/store/pcstore/cc/domain/trans/trans_task.h +++ b/ucm/store/pcstore/cc/domain/trans/trans_task.h @@ -58,10 +58,10 @@ class TransTask { : id{NextId()}, type{std::move(type)}, startTp{NowTp()}, brief_{std::move(brief)} { } - void Append(const std::string& block, const uintptr_t address) + void Append(const std::string& block, const std::vector<uintptr_t>& addresses) { - grouped_[block].push_back(address); - number_++; + grouped_[block] = addresses; + number_ += addresses.size(); } auto Str() const noexcept { return fmt::format("{},{},{}", id, brief_, number_); } size_t GroupNumber() const { return grouped_.size(); } diff --git a/ucm/store/pcstore/cpy/pcstore.py.cc b/ucm/store/pcstore/cpy/pcstore.py.cc index e8a11346c..aa55a2390 100644 --- a/ucm/store/pcstore/cpy/pcstore.py.cc +++ b/ucm/store/pcstore/cpy/pcstore.py.cc @@ -71,7 +71,9 @@ class PcStorePy : public PcStore { auto blockId = blockIds.begin(); auto address = addresses.begin(); while ((blockId != blockIds.end()) && (address != addresses.end())) { - task.Append(blockId->cast<std::string>(), address->cast<uintptr_t>()); + std::string id = blockId->cast<std::string>(); + std::vector<uintptr_t> addrs = address->cast<std::vector<uintptr_t>>(); + task.Append(id, addrs); blockId++; address++; } diff --git a/ucm/store/pcstore/pcstore_connector.py b/ucm/store/pcstore/pcstore_connector.py index 13b0b0b6e..e16fef507 100644 --- a/ucm/store/pcstore/pcstore_connector.py +++ b/ucm/store/pcstore/pcstore_connector.py @@ -28,22 +28,22 @@ import torch from ucm.store.pcstore import ucmpcstore -from ucm.store.ucmstore import Task, UcmKVStoreBase +from ucm.store.ucmstore_v1 import Task, UcmKVStoreBaseV1 @dataclass -class 
NfsTask(Task): +class PcTask(Task): task_id: int -class UcmPcStore(UcmKVStoreBase): +class UcmPcStore(UcmKVStoreBaseV1): def __init__(self, config: Dict): super().__init__(config) self.store = ucmpcstore.PcStore() storage_backends = [ path for path in config["storage_backends"].split(":") if path ] - block_size = int(config["kv_block_size"]) + block_size = config.get("kv_block_size", 33554432) transfer_enable = True if config["role"] == "worker" else False param = ucmpcstore.PcStore.Config(storage_backends, block_size, transfer_enable) if transfer_enable: @@ -52,8 +52,8 @@ def __init__(self, config: Dict): param.transferIoDirect = config.get("use_direct", False) param.transferStreamNumber = config.get("stream_number", 8) param.transferBufferNumber = config.get("buffer_number", 4096) - param.transferLocalRankSize = config.get("local_rank_size", 8) param.transferScatterGatherEnable = config.get("use_scatter_gatter", False) + param.transferLocalRankSize = config.get("local_rank_size", 1) ret = self.store.Setup(param) if ret != 0: msg = f"Failed to initialize ucmpcstore, errcode: {ret}." 
@@ -62,52 +62,54 @@ def __init__(self, config: Dict): def cc_store(self) -> int: return self.store.CCStoreImpl() - def create(self, block_ids: List[str]) -> List[int]: - return self.store.AllocBatch(block_ids) - - def lookup(self, block_ids: List[str]) -> List[bool]: + def lookup(self, block_ids: List[bytes]) -> List[bool]: return self.store.LookupBatch(block_ids) - def prefetch(self, block_ids: List[str]) -> None: + def prefetch(self, block_ids: List[bytes]) -> None: pass def load( - self, block_ids: List[str], offset: List[int], dst_tensor: List[torch.Tensor] + self, + block_ids: List[bytes], + shard_index: List[int], + dst_tensor: List[List[torch.Tensor]], ) -> Task: - dst_tensor_ptr = [t.data_ptr() for t in dst_tensor] - task_id = self.store.LoadToDevice(block_ids, dst_tensor_ptr) - return NfsTask(task_id=task_id) + dst_tensor_ptrs = [[t.data_ptr() for t in tensors] for tensors in dst_tensor] + task_id = self.store.LoadToDevice(block_ids, dst_tensor_ptrs) + return PcTask(task_id=task_id) def dump( - self, block_ids: List[str], offset: List[int], src_tensor: List[torch.Tensor] + self, + block_ids: List[bytes], + shard_index: List[int], + src_tensor: List[List[torch.Tensor]], ) -> Task: - src_tensor_ptr = [t.data_ptr() for t in src_tensor] - task_id = self.store.DumpFromDevice(block_ids, src_tensor_ptr) - return NfsTask(task_id=task_id) + src_tensor_ptrs = [[t.data_ptr() for t in tensors] for tensors in src_tensor] + task_id = self.store.DumpFromDevice(block_ids, src_tensor_ptrs) + return PcTask(task_id=task_id) def fetch_data( self, - block_ids: List[str], - offset: List[int], - dst_addr: List[int], - size: List[int], + block_ids: List[bytes], + shard_index: List[int], + dst_addr: List[List[int]], ) -> Task: - pass + task_id = self.store.LoadToDevice(block_ids, dst_addr) + return task_id def dump_data( self, - block_ids: List[str], - offset: List[int], - src_addr: List[int], - size: List[int], + block_ids: List[bytes], + shard_index: List[int], + src_addr: 
List[List[int]], ) -> Task: - pass - - def wait(self, task: Task) -> int: - return self.store.Wait(task.task_id) + task_id = self.store.DumpFromDevice(block_ids, src_addr) + return task_id - def commit(self, block_ids: List[str], is_success: bool = True) -> None: - self.store.CommitBatch(block_ids, is_success) + def wait(self, task: Task) -> None: + ret = self.store.Wait(task.task_id) + if ret != 0: + raise RuntimeError(f"Wait failed for task {task.task_id}, return={ret}") - def check(self, task: Task) -> Tuple[int, bool]: + def check(self, task: Task) -> bool: return self.store.Check(task.task_id) diff --git a/ucm/store/test/e2e/pcstore_embed.py b/ucm/store/test/e2e/pcstore_embed.py index da3e9de86..7138657fa 100644 --- a/ucm/store/test/e2e/pcstore_embed.py +++ b/ucm/store/test/e2e/pcstore_embed.py @@ -30,10 +30,10 @@ import torch from ucm.store.pcstore.pcstore_connector import UcmPcStore -from ucm.store.ucmstore import UcmKVStoreBase +from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1 -def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreBase: +def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreBaseV1: config = {} config["storage_backends"] = storage_backends config["kv_block_size"] = block_size @@ -46,7 +46,7 @@ def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreB def make_buffers( block_number, device_id, batch_size, block_dim, block_len, block_layer ): - hashes = [secrets.token_hex(16) for _ in range(block_number)] + hashes = [secrets.token_bytes(16) for _ in range(block_number)] tensors = [ [ torch.rand( @@ -61,44 +61,25 @@ def make_buffers( return hashes, tensors -def embed(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): - results = store.create(hashes) - assert sum(results) == 0 - block_ids = [] - offsets = [] - layers = [] - for hash_id, block in zip(hashes, tensors): - offset = 0 - for layer in block: - block_ids.append(hash_id) - 
offsets.append(offset) - layers.append(layer) - offset += layer.untyped_storage().size() - task = store.dump(block_ids, offsets, layers) +def embed( + store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]] +): + shard_index = [0] * len(hashes) + task = store.dump(hashes, shard_index, tensors) assert task.task_id > 0 - ret = store.wait(task) - assert ret == 0 - store.commit(hashes, True) + store.wait(task) -def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): +def fetch( + store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]] +): founds = store.lookup(hashes) for found in founds: assert found - block_ids = [] - offsets = [] - layers = [] - for hash_id, block in zip(hashes, tensors): - offset = 0 - for layer in block: - block_ids.append(hash_id) - offsets.append(offset) - layers.append(layer) - offset += layer.untyped_storage().size() - task = store.load(block_ids, offsets, layers) + shard_index = [0] * len(hashes) + task = store.load(hashes, shard_index, tensors) assert task.task_id > 0 - ret = store.wait(task) - assert ret == 0 + store.wait(task) def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0): @@ -120,7 +101,7 @@ def store_all_hashes(hashes): file_path = os.path.join(current_directory, kvcache_block_hashes_file) with open(file_path, "w", encoding="utf-8") as file: for hs in hashes: - file.write(hs + "\n") + file.write(hs.hex() + "\n") def main(): diff --git a/ucm/store/test/e2e/pcstore_fetch.py b/ucm/store/test/e2e/pcstore_fetch.py index 6299d387d..1d3996113 100644 --- a/ucm/store/test/e2e/pcstore_fetch.py +++ b/ucm/store/test/e2e/pcstore_fetch.py @@ -29,10 +29,10 @@ import torch from ucm.store.pcstore.pcstore_connector import UcmPcStore -from ucm.store.ucmstore import UcmKVStoreBase +from ucm.store.ucmstore_v1 import UcmKVStoreBaseV1 -def setup_store(storage_backends, block_size, device_id, io_size) -> UcmKVStoreBase: +def setup_store(storage_backends, block_size, device_id, 
io_size) -> UcmKVStoreBaseV1: config = {} config["storage_backends"] = storage_backends config["kv_block_size"] = block_size @@ -48,7 +48,7 @@ def get_hashes(batch_size, batch_number): file_path = os.path.join(current_directory, kvcache_block_hashes_file) with open(file_path, "r", encoding="utf-8") as file: lines = file.readlines() - total = [line.strip() for line in lines] + total = [bytes.fromhex(line.strip()) for line in lines] hashes = [] for _ in range(batch_number): hashes.extend(random.sample(total, batch_size)) @@ -70,24 +70,16 @@ def make_buffers(device_id, batch_size, block_dim, block_len, block_layer): return tensors -def fetch(store: UcmKVStoreBase, hashes: List[str], tensors: List[List[torch.Tensor]]): +def fetch( + store: UcmKVStoreBaseV1, hashes: List[bytes], tensors: List[List[torch.Tensor]] +): founds = store.lookup(hashes) for found in founds: assert found - block_ids = [] - offsets = [] - layers = [] - for hash_id, block in zip(hashes, tensors): - offset = 0 - for layer in block: - block_ids.append(hash_id) - offsets.append(offset) - layers.append(layer) - offset += layer.untyped_storage().size() - task = store.load(block_ids, offsets, layers) + shard_index = [0] * len(hashes) + task = store.load(hashes, shard_index, tensors) assert task.task_id > 0 - ret = store.wait(task) - assert ret == 0 + store.wait(task) def main(): diff --git a/ucm/store/ucmstore_v1.py b/ucm/store/ucmstore_v1.py new file mode 100644 index 000000000..0fde34ec8 --- /dev/null +++ b/ucm/store/ucmstore_v1.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# +# MIT License +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +from abc import ABC, abstractmethod +from typing import Dict, List, Tuple + +import torch + + +class Task(ABC): + """Asynchronous task handle returned by transfer operations. + + This is an opaque token that can be polled or awaited. + """ + + pass + + +class UcmKVStoreBaseV1(ABC): + """Abstract base class for KV-cache-centric storage backends. + + A concrete storage vendor must implement this interface to participate in + the unified-cache-management (UCM) system. + """ + + def __init__(self, config: Dict[str, object]) -> None: + """Initialize the store with vendor-specific configuration. + + Args: + config: Key-value mapping containing vendor-specific parameters + (e.g., connection string, cache size, compression level). + """ + self.config = config + + @abstractmethod + def cc_store(self) -> int: + """Return a low-level C/C++ pointer to the underlying store. 
+ + Returns: + An opaque ``int`` representing the ``Store*`` instance that can + be passed to native code. + """ + pass + + @abstractmethod + def lookup(self, block_ids: List[bytes]) -> List[bool]: + """Check presence of blocks in external storage. + + Args: + block_ids: List of vLLM block hashes (raw bytes). + + Returns: + A list of booleans, ``True`` if the corresponding block exists in + storage, ``False`` otherwise. The order matches ``block_ids``. + """ + pass + + @abstractmethod + def prefetch(self, block_ids: List[bytes]) -> None: + """Asynchronously prefetch blocks into high-speed cache. + + Args: + block_ids: List of vLLM block hashes to prefetch. + """ + pass + + @abstractmethod + def load( + self, + block_ids: List[bytes], + shard_index: List[int], + dst_tensor: List[List[torch.Tensor]], + ) -> Task: + """Initiate transfer of KV cache from storage to device. + + Args: + block_ids: Hashes of the blocks to load. + shard_index: Shard index for each block. + dst_tensor: Double-list structure where ``dst_tensor[i][j]`` is the + destination PyTorch tensor on device for block ``i``, tensor ``j``. + + Returns: + A ``Task`` handle that can be used to check or wait for completion. + """ + pass + + @abstractmethod + def dump( + self, + block_ids: List[bytes], + shard_index: List[int], + src_tensor: List[List[torch.Tensor]], + ) -> Task: + """Initiate transfer of KV cache from device to storage. + + Args: + block_ids: Hashes of the blocks to write. + shard_index: Shard index for each block. + src_tensor: Double-list structure where ``src_tensor[i][j]`` is the + source PyTorch tensor on device for block ``i``, tensor ``j``. + + Returns: + A ``Task`` handle that can be used to check or wait for completion. + """ + pass + + @abstractmethod + def fetch_data( + self, + block_ids: List[bytes], + shard_index: List[int], + dst_addr: List[List[int]], + ) -> Task: + """Low-level fetch: copy KV data to device pointers. + + Args: + block_ids: Block hashes to load. 
+ shard_index: Shard index for each block. + dst_addr: Double-list of ``int`` pointers (as Python ``int``) to + pre-allocated device buffers. + + Returns: + A ``Task`` handle for the asynchronous copy. + """ + pass + + @abstractmethod + def dump_data( + self, + block_ids: List[bytes], + shard_index: List[int], + src_addr: List[List[int]], + ) -> Task: + """Low-level dump: copy KV data from device pointers. + + Args: + block_ids: Block hashes to store. + shard_index: Shard index for each block. + src_addr: Double-list of ``int`` pointers to device buffers. + + Returns: + A ``Task`` handle for the asynchronous copy. + """ + pass + + @abstractmethod + def wait(self, task: Task) -> None: + """Block until the given transfer task completes. + + Args: + task: Task handle returned by ``load``, ``dump``, ``fetch_data``, + or ``dump_data``. + """ + pass + + @abstractmethod + def check(self, task: Task) -> bool: + """Non-blocking poll for task completion. + + Args: + task: Task handle returned by any transfer method. + + Returns: + ``True`` if the task has finished, ``False`` if still in-flight. + """ + pass