From 1c4fee5680f0b6a473d5d12fef4ccca11bc81bc5 Mon Sep 17 00:00:00 2001 From: linjiangxian Date: Fri, 22 May 2026 07:56:34 +0000 Subject: [PATCH 1/2] [FEA] remove unnecessary keys selection and fuse selection to erase --- .../dynamicemb/batched_dynamicemb_function.py | 20 +++++++++---------- .../dynamicemb/embedding_admission.py | 4 ++-- .../dynamicemb/dynamicemb/scored_hashtable.py | 6 ++++++ corelib/dynamicemb/dynamicemb/types.py | 4 +++- 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py index 1cc1db420..add88f6f0 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py @@ -349,12 +349,11 @@ def _prefetch_cache_path( keys_to_insert_mask = storage_founds.clone() if new_in_miss.any() and admit_strategy is not None: - new_miss_indices = torch.where(new_in_miss)[0] - new_keys_sub = miss_keys[new_miss_indices] - new_tids_sub = miss_tids[new_miss_indices] + new_keys_sub = miss_keys[new_in_miss] + new_tids_sub = miss_tids[new_in_miss] freq_for_admission = ( - miss_lfu_freq[new_miss_indices] if miss_lfu_freq is not None else None + miss_lfu_freq[new_in_miss] if miss_lfu_freq is not None else None ) counters = ( freq_for_admission @@ -366,14 +365,13 @@ def _prefetch_cache_path( if admit_mask.any(): admission_counter.erase( - new_keys_sub[admit_mask], new_tids_sub[admit_mask] + new_keys_sub, new_tids_sub, mask=admit_mask ) - keys_to_insert_mask[new_miss_indices[admit_mask]] = True + keys_to_insert_mask[new_in_miss] = admit_mask - non_admit = ~admit_mask - if non_admit.any(): - non_admitted_miss_pos = new_miss_indices[non_admit] - non_admitted_positions = miss_compact_idx[non_admitted_miss_pos] + non_admit_miss = new_in_miss & ~keys_to_insert_mask + if non_admit_miss.any(): + non_admitted_positions = miss_compact_idx[non_admit_miss] elif new_in_miss.any(): keys_to_insert_mask = torch.ones( h_num_miss, dtype=torch.bool, device=device @@ -586,7 +584,7 @@ def _prefetch_hbm_direct_path( if admit_mask.any(): admission_counter.erase( - missing_keys[admit_mask], missing_table_ids[admit_mask] + missing_keys, missing_table_ids, mask=admit_mask ) non_admit = ~admit_mask diff --git a/corelib/dynamicemb/dynamicemb/embedding_admission.py b/corelib/dynamicemb/dynamicemb/embedding_admission.py index fe0f72410..8a8f2e5fe 100644 --- a/corelib/dynamicemb/dynamicemb/embedding_admission.py +++ b/corelib/dynamicemb/dynamicemb/embedding_admission.py @@ -91,8 +91,8 @@ def add( self.table_.insert(keys, table_ids, self.score_arg_, score_out=scores_out) return scores_out - def erase(self, keys: torch.Tensor, table_ids: torch.Tensor) -> None: - self.table_.erase(keys, table_ids) + def erase(self, keys: torch.Tensor, table_ids: torch.Tensor, mask: Optional[torch.Tensor] = None) -> None: + self.table_.erase(keys, table_ids, mask=mask) def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: return self.table_.memory_usage(mem_type) diff --git a/corelib/dynamicemb/dynamicemb/scored_hashtable.py b/corelib/dynamicemb/dynamicemb/scored_hashtable.py index 021866c9d..c6ae601dc 100644 --- a/corelib/dynamicemb/dynamicemb/scored_hashtable.py +++ b/corelib/dynamicemb/dynamicemb/scored_hashtable.py @@ -168,11 +168,13 @@ def erase( self, keys: torch.Tensor, table_ids: torch.Tensor, + mask: Optional[torch.Tensor] = None, ) -> None: """ Erase Keys Args: table_ids: int32 tensor of same length as keys, identifying which logical table each key belongs to. + mask: Optional boolean mask. If provided, only masked positions are erased. """ @abc.abstractmethod @@ -787,12 +789,15 @@ def erase( self, keys: torch.Tensor, table_ids: torch.Tensor, + mask: Optional[torch.Tensor] = None, ) -> None: """ Erase Keys Args: table_ids: int32 tensor of same length as keys, identifying which logical table each key belongs to. + mask: Optional boolean mask. If provided, only masked positions are erased. """ + indices = torch.where(mask)[0] if mask is not None else None table_erase( self.table_storage_, self.table_bucket_offsets_, @@ -800,6 +805,7 @@ def erase( self.bucket_sizes, keys, table_ids, + indices, ) def load( diff --git a/corelib/dynamicemb/dynamicemb/types.py b/corelib/dynamicemb/dynamicemb/types.py index a8c2b711f..367952443 100644 --- a/corelib/dynamicemb/dynamicemb/types.py +++ b/corelib/dynamicemb/dynamicemb/types.py @@ -337,13 +337,15 @@ def add( ... @abc.abstractmethod - def erase(self, keys: torch.Tensor, table_ids: torch.Tensor) -> None: + def erase(self, keys: torch.Tensor, table_ids: torch.Tensor, mask: Optional[torch.Tensor] = None) -> None: """ Erase keys from the `Counter`. Args: keys (torch.Tensor): The input keys to be erased. table_ids (torch.Tensor): The table id for each key. + mask (Optional[torch.Tensor]): Boolean mask of same length as keys. + If provided, only masked positions are erased. """ @abc.abstractmethod From 1012a865e6df1697cab759d3e891a3bb20821543 Mon Sep 17 00:00:00 2001 From: linjiangxian Date: Fri, 22 May 2026 09:22:35 +0000 Subject: [PATCH 2/2] fix --- .../dynamicemb/dynamicemb/batched_dynamicemb_function.py | 6 +++--- corelib/dynamicemb/dynamicemb/embedding_admission.py | 4 ++-- corelib/dynamicemb/dynamicemb/scored_hashtable.py | 6 ------ corelib/dynamicemb/dynamicemb/types.py | 4 +--- 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py index add88f6f0..47055bf46 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py @@ -365,9 +365,9 @@ def _prefetch_cache_path( if admit_mask.any(): admission_counter.erase( - new_keys_sub, new_tids_sub, mask=admit_mask + new_keys_sub[admit_mask], new_tids_sub[admit_mask] ) - keys_to_insert_mask[new_in_miss] = admit_mask + keys_to_insert_mask[new_in_miss] = admit_mask non_admit_miss = new_in_miss & ~keys_to_insert_mask if non_admit_miss.any(): @@ -584,7 +584,7 @@ def _prefetch_hbm_direct_path( if admit_mask.any(): admission_counter.erase( - missing_keys, missing_table_ids, mask=admit_mask + missing_keys[admit_mask], missing_table_ids[admit_mask] ) non_admit = ~admit_mask diff --git a/corelib/dynamicemb/dynamicemb/embedding_admission.py b/corelib/dynamicemb/dynamicemb/embedding_admission.py index 8a8f2e5fe..fe0f72410 100644 --- a/corelib/dynamicemb/dynamicemb/embedding_admission.py +++ b/corelib/dynamicemb/dynamicemb/embedding_admission.py @@ -91,8 +91,8 @@ def add( self.table_.insert(keys, table_ids, self.score_arg_, score_out=scores_out) return scores_out - def erase(self, keys: torch.Tensor, table_ids: torch.Tensor, mask: Optional[torch.Tensor] = None) -> None: - self.table_.erase(keys, table_ids, mask=mask) + def erase(self, keys: torch.Tensor, table_ids: torch.Tensor) -> None: + self.table_.erase(keys, table_ids) def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: return self.table_.memory_usage(mem_type) diff --git a/corelib/dynamicemb/dynamicemb/scored_hashtable.py b/corelib/dynamicemb/dynamicemb/scored_hashtable.py index c6ae601dc..021866c9d 100644 --- a/corelib/dynamicemb/dynamicemb/scored_hashtable.py +++ b/corelib/dynamicemb/dynamicemb/scored_hashtable.py @@ -168,13 +168,11 @@ def erase( self, keys: torch.Tensor, table_ids: torch.Tensor, - mask: Optional[torch.Tensor] = None, ) -> None: """ Erase Keys Args: table_ids: int32 tensor of same length as keys, identifying which logical table each key belongs to. - mask: Optional boolean mask. If provided, only masked positions are erased. """ @abc.abstractmethod @@ -789,15 +787,12 @@ def erase( self, keys: torch.Tensor, table_ids: torch.Tensor, - mask: Optional[torch.Tensor] = None, ) -> None: """ Erase Keys Args: table_ids: int32 tensor of same length as keys, identifying which logical table each key belongs to. - mask: Optional boolean mask. If provided, only masked positions are erased. """ - indices = torch.where(mask)[0] if mask is not None else None table_erase( self.table_storage_, self.table_bucket_offsets_, @@ -805,7 +800,6 @@ def erase( self.bucket_sizes, keys, table_ids, - indices, ) def load( diff --git a/corelib/dynamicemb/dynamicemb/types.py b/corelib/dynamicemb/dynamicemb/types.py index 367952443..a8c2b711f 100644 --- a/corelib/dynamicemb/dynamicemb/types.py +++ b/corelib/dynamicemb/dynamicemb/types.py @@ -337,15 +337,13 @@ def add( ... @abc.abstractmethod - def erase(self, keys: torch.Tensor, table_ids: torch.Tensor, mask: Optional[torch.Tensor] = None) -> None: + def erase(self, keys: torch.Tensor, table_ids: torch.Tensor) -> None: """ Erase keys from the `Counter`. Args: keys (torch.Tensor): The input keys to be erased. table_ids (torch.Tensor): The table id for each key. - mask (Optional[torch.Tensor]): Boolean mask of same length as keys. - If provided, only masked positions are erased. """ @abc.abstractmethod