Refactor/prefetch select cleanup #409

Merged
shijieliu merged 5 commits into NVIDIA:main from jiashuy:refactor/prefetch-select-cleanup
Jun 1, 2026

@jiashuy
Collaborator

@jiashuy jiashuy commented May 29, 2026

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

ShaobinChen-AH and others added 4 commits May 22, 2026 07:56
- _prefetch_hbm_direct_path: hoist admitted_keys/tids/scores/positions
  computation above erase and reuse them, removing the duplicate
  missing_keys[admit_mask]/missing_table_ids[admit_mask] gathers.
- _prefetch_cache_path: simplify non_admit_miss to ~keys_to_insert_mask;
  the & new_in_miss term is a no-op since keys_to_insert_mask keeps
  storage-found positions True, so ~keys_to_insert_mask is already a
  subset of new_in_miss.

No behavior change.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
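
The subset argument in the commit message can be checked mechanically. The sketch below is a minimal illustration, not the code from batched_dynamicemb_function.py. It assumes that new_in_miss is the complement of storage_founds and that keys_to_insert_mask starts as a copy of storage_founds; the helper names non_admit_masks and check_all are hypothetical.

```python
import itertools

import torch


def non_admit_masks(storage_founds: torch.Tensor, admit_mask: torch.Tensor):
    """Return the (old, new) forms of non_admit_miss for one miss batch.

    Assumptions (not confirmed by the PR text): new_in_miss is
    ~storage_founds, and keys_to_insert_mask starts as storage_founds.clone().
    admit_mask holds one entry per True position of new_in_miss.
    """
    new_in_miss = ~storage_founds
    keys_to_insert_mask = storage_founds.clone()
    if new_in_miss.any():
        keys_to_insert_mask[new_in_miss] = admit_mask
    old = ~keys_to_insert_mask & new_in_miss  # form before the cleanup
    new = ~keys_to_insert_mask                # form after the cleanup
    return old, new


def check_all(n: int = 4) -> bool:
    """Compare both forms for every storage/admit combination of length n."""
    for found_bits in itertools.product([False, True], repeat=n):
        storage_founds = torch.tensor(found_bits, dtype=torch.bool)
        n_new = int((~storage_founds).sum())
        for admit_bits in itertools.product([False, True], repeat=n_new):
            admit_mask = torch.tensor(admit_bits, dtype=torch.bool)
            old, new = non_admit_masks(storage_founds, admit_mask)
            if not torch.equal(old, new):
                return False
    return True


all_equal = check_all(4)
```

Storage-found positions stay True in keys_to_insert_mask, so its complement can only be True where new_in_miss is True, which is why the extra AND never changes the result under these assumptions.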
@jiashuy
Collaborator Author

jiashuy commented May 29, 2026

/build

@greptile-apps
Contributor

greptile-apps Bot commented May 29, 2026

Greptile Summary

This PR refactors two prefetch functions (_prefetch_cache_path and _prefetch_hbm_direct_path) in batched_dynamicemb_function.py to simplify admission-control logic by replacing integer indexing (via torch.where) with direct boolean-mask indexing, and by pre-computing admitted-key variables before the counter-erase call.

  • _prefetch_cache_path: Drops new_miss_indices = torch.where(new_in_miss)[0] and indexes miss_keys/miss_tids/miss_lfu_freq directly with the boolean new_in_miss mask; replaces keys_to_insert_mask[new_miss_indices[admit_mask]] = True with the single assignment keys_to_insert_mask[new_in_miss] = admit_mask; and computes non-admitted positions from ~keys_to_insert_mask instead of a separately indexed ~admit_mask.
  • _prefetch_hbm_direct_path: Pre-computes admitted_keys/admitted_tids/admitted_scores/admitted_unique_positions before the conditional admission_counter.erase, passing the pre-sliced tensors directly to erase instead of re-slicing missing_keys[admit_mask].
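
The two changes summarized above can be illustrated on toy tensors. This is a hedged sketch, not the PR's code: the tensor values are made up, and RecordingCounter with its erase(keys, table_ids) method is a hypothetical stand-in for the real admission counter, whose signature is not shown in this conversation.

```python
import torch

# --- Cache path: boolean-mask gather vs. torch.where + integer gather ---
miss_keys = torch.tensor([11, 22, 33, 44, 55])
new_in_miss = torch.tensor([True, False, True, True, False])

# Old style: materialize integer indices first, then gather.
new_miss_indices = torch.where(new_in_miss)[0]
old_keys = miss_keys[new_miss_indices]

# New style: gather with the boolean mask directly.
new_keys = miss_keys[new_in_miss]
same_gather = torch.equal(old_keys, new_keys)


# --- HBM direct path: slice admitted tensors once, then reuse them ---
class RecordingCounter:
    """Hypothetical stand-in that only records what erase() receives."""

    def __init__(self):
        self.erased = []

    def erase(self, keys, table_ids):
        self.erased.append((keys, table_ids))


missing_keys = torch.tensor([7, 8, 9, 10])
missing_table_ids = torch.tensor([0, 0, 1, 1])
admit_mask = torch.tensor([True, False, True, False])

# Hoisted above the erase call so later code can reuse the same slices.
admitted_keys = missing_keys[admit_mask]
admitted_tids = missing_table_ids[admit_mask]

counter = RecordingCounter()
if admit_mask.any():
    counter.erase(admitted_keys, admitted_tids)
```

Both gathers visit the True positions in ascending order, so the results match element for element; hoisting only removes the second pair of gathers.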

Confidence Score: 5/5

Safe to merge — both changed paths produce identical results to the original code.

The refactoring replaces integer-index extraction (torch.where + fancy indexing) with direct boolean-mask indexing throughout both prefetch paths. In _prefetch_cache_path, keys_to_insert_mask[new_in_miss] = admit_mask is equivalent to the old keys_to_insert_mask[new_miss_indices[admit_mask]] = True because the non-admitted positions were already False from the storage_founds.clone() initialization, and the wider ~keys_to_insert_mask for non-admitted-position computation remains limited to non-storage-found, non-admitted slots. In _prefetch_hbm_direct_path, pre-computing the admitted_* tensors before the erase and passing them directly instead of re-slicing is a pure reorganization. No edge-case regressions were found.
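
The equivalence claimed for _prefetch_cache_path can be reproduced on a small example. The sketch below is an illustration under stated assumptions rather than the reviewed code: it assumes new_in_miss is ~storage_founds, and old_update and new_update are hypothetical helper names wrapping the two assignment styles quoted above.

```python
import torch


def old_update(storage_founds: torch.Tensor, admit_mask: torch.Tensor) -> torch.Tensor:
    """Pre-cleanup style: scatter True through integer indices."""
    new_in_miss = ~storage_founds  # assumption: new keys are those not in storage
    keys_to_insert_mask = storage_founds.clone()
    new_miss_indices = torch.where(new_in_miss)[0]
    keys_to_insert_mask[new_miss_indices[admit_mask]] = True
    return keys_to_insert_mask


def new_update(storage_founds: torch.Tensor, admit_mask: torch.Tensor) -> torch.Tensor:
    """Post-cleanup style: write admit_mask into the new-key positions."""
    new_in_miss = ~storage_founds
    keys_to_insert_mask = storage_founds.clone()
    keys_to_insert_mask[new_in_miss] = admit_mask
    return keys_to_insert_mask


storage_founds = torch.tensor([True, False, False, True, False])
admit_mask = torch.tensor([False, True, True])  # one entry per new key

old_mask = old_update(storage_founds, admit_mask)
new_mask = new_update(storage_founds, admit_mask)
masks_match = torch.equal(old_mask, new_mask)
```

The new form also writes False into non-admitted positions, which the old form left untouched. The results agree only because those positions already start as False; with a different initialization of keys_to_insert_mask the two forms could diverge.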

No files require special attention.

Important Files Changed

| Filename | Overview |
| --- | --- |
| corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py | Boolean-mask indexing replaces integer-index extraction in both prefetch paths; logic is semantically equivalent and the non-admitted-positions computation correctly widens to ~keys_to_insert_mask, which is still limited to non-storage-found, non-admitted keys. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["miss_keys / miss_tids"] --> B{"new_in_miss.any AND admit_strategy?"}
    B -- "No admit_strategy" --> C["keys_to_insert_mask = all-ones"]
    B -- "Yes" --> D["Slice: new_keys_sub = miss_keys[new_in_miss]"]
    D --> E["admission_counter.add → freq"]
    E --> F["admit_strategy.admit → admit_mask"]
    F --> G{"admit_mask.any?"}
    G -- "Yes" --> H["admission_counter.erase admitted keys"]
    H --> I["keys_to_insert_mask[new_in_miss] = admit_mask"]
    G -- "No" --> J["new_in_miss positions remain False"]
    I --> K["non_admit_miss = ~keys_to_insert_mask"]
    J --> K
    K --> L{"non_admit_miss.any?"}
    L -- "Yes" --> M["non_admitted_positions = miss_compact_idx[non_admit_miss]"]
    L -- "No" --> N["non_admitted_positions = None"]
    M --> O{"keys_to_insert_mask.any?"}
    N --> O
    O -- "No" --> P["Set slot_indices = -1, return early"]
    O -- "Yes" --> Q["flagged_compact → insert_keys / insert_tids"]
    Q --> R["cache.insert_and_evict"]
    C --> O

Reviews (2): Last reviewed commit: "Merge branch 'main' into refactor/prefet..."

@JacoCheung
Collaborator

JacoCheung commented May 29, 2026

Pipeline #53032067 -- failed

| Job | Status |
| --- | --- |
| pre_check | ✅ success |
| train_build | ✅ success |
| inference_build | ✅ success |
| tritonserver_build | ✅ success |
| build_whl | ✅ success |
| dynamicemb_test_fwd_bwd_8gpus | ✅ success |
| dynamicemb_test_load_dump_8gpus | ✅ success |
| unit_test_1gpu_a100 | ❌ failed |
| unit_test_1gpu_h100 | ✅ success |
| unit_test_4gpu | ✅ success |
| unit_test_tp_4gpu | ❌ failed |
| L20_unit_test_1gpu | ✅ success |
| inference_unit_test_1gpu | ✅ success |
| inference_test_1gpu | ✅ success |

Result: 12/14 jobs passed


@jiashuy
Collaborator Author

jiashuy commented May 29, 2026

TODO: benchmark it

@jiashuy
Collaborator Author

jiashuy commented Jun 1, 2026

[image]

@jiashuy
Collaborator Author

jiashuy commented Jun 1, 2026

Performance matches the previous implementation, and all tests have passed. This can be merged.

@jiashuy jiashuy requested a review from shijieliu June 1, 2026 00:40