diff --git a/cherry-pick.md b/cherry-pick.md new file mode 100644 index 000000000..cc48bc2e2 --- /dev/null +++ b/cherry-pick.md @@ -0,0 +1,203 @@ +# Cherry-Pick Dependency Analysis: Embedding Admission Strategy + +Goal: get "Embedding admission strategy" (#236) working on top of base commit +`e48294d3` (Add LFU evict strategy, #52). + +All 5 cherry-picks are required with the current local chain. + +## Cherry-Pick Chain + +| # | Local Commit | Upstream | Description | Why needed | +|---|-------------|----------|-------------|------------| +| 1 | `3c79809` | `6f7281a` | Refactor dynamicemb with Cache&Storage (#128) | Introduces `key_value_table.py`, `batched_dynamicemb_tables.py` (V2), `batched_dynamicemb_function.py` — the entire architecture that #236 builds on. These files do not exist at `e48294d3`. | +| 2 | `b818725` | `44525fb` | Support eval mode, move insert to forward (#136) | Introduces `find_and_initialize`, `lookup_forward_dense_eval`, and the `EventQueue` pattern used by `key_value_table.py`. Required to make #128 functional. | +| 3 | `befb8c9` | `d497241` | Fix LFU mode frequency count bug (#176) | **Introduces `types.py`** as a new file in the local cherry-pick. Commit #229 then modifies `types.py` — without this step, #229 will not apply. | +| 4 | `78c8ebc` | `c6adf64` | Counter table interface and ScoredHashTable (#229) | Introduces `scored_hashtable.py`. `embedding_admission.py` (#236) directly imports `ScoreArg`, `ScorePolicy`, `ScoreSpec`, and `get_scored_table` from it. | +| 5 | `68266e9` | `f5b608e` | Embedding admission strategy (#236) | The target feature. | + +## Could Any Be Skipped? + +In the **upstream** repo, `types.py` is first introduced by commit #229 itself +(`c6adf64`, confirmed via `--diff-filter=A`), not by #176. So if cherry-picking +raw upstream commits, the chain might reduce to 3 picks: +`6f7281a` (#128) → `c6adf64` (#229) → `f5b608e` (#236). + +However, the local versions of these commits diverged during conflict resolution +(#136 and #176 were applied with `-X theirs`, modifying the same files). As a +result, the local #229 (`78c8ebc`) modifies `types.py` rather than creating it, +making the local #176 (`befb8c9`) a required intermediate step. + +**Conclusion:** With the current local cherry-picked commits, all 5 are necessary. + +--- + +## Bugs Found in Commit #5 (`68266e9` — Embedding Admission Strategy) + +The local cherry-pick (`68266e9`) was sourced from the jiashuy fork (`c2babbe`) +rather than NVIDIA upstream (`f5b608e`). The two diverge significantly in +`batched_dynamicemb_function.py` and `key_value_table.py`, leaving the admission +strategy non-functional. + +### Bug 1 — `DynamicEmbeddingFunctionV2.forward`: missing `admit_strategy`, `evict_strategy`, `admission_counter` parameters + +`batched_dynamicemb_tables.py` calls `DynamicEmbeddingFunctionV2.apply()` with +these positional arguments: + +| Position | Passed value | Received as (local) | Should be received as | +|----------|-------------|---------------------|-----------------------| +| 13 | `self._admit_strategy` | `frequency_counters` | `admit_strategy` | +| 14 | `self._evict_strategy` | `*args` (silently ignored) | `evict_strategy` | +| 15 | `per_sample_weights` | `*args` (silently ignored) | `frequency_counters` | +| 16 | `self._admission_counter` | `*args` (silently ignored) | `admission_counter` | + +When admission is enabled (`_admit_strategy is not None`), the function +immediately crashes on `frequency_counters.long()` because it received an +`AdmissionStrategy` object instead of a tensor. + +**Fix:** Add `admit_strategy=None`, `evict_strategy=None`, `admission_counter=None` +to the signature (before `frequency_counters`), matching the upstream layout. + +### Bug 2 — `DynamicEmbeddingFunctionV2.forward`: `segmented_unique` receives a `bool` instead of `Optional[EvictStrategy]` + +```python +# Local (wrong): +segmented_unique(indices, indices_table_range, unique_op, + is_lfu_enabled, # bool: False→kLru, True→kLfu by accident + frequency_counts_int64) + +# Upstream (correct): +segmented_unique(indices, indices_table_range, unique_op, + EvictStrategy(evict_strategy.value) if evict_strategy else None, + frequency_counts_int64) +``` + +The C++ binding expects `c10::optional`. Passing `False` maps to +`kLru` (value 0) instead of `nullopt`, causing incorrect eviction-strategy +selection for non-LFU tables. + +**Fix:** Use `evict_strategy` parameter and convert with +`EvictStrategy(evict_strategy.value) if evict_strategy else None`. + +### Bug 3 — `DynamicEmbeddingFunctionV2.forward`: lookup calls missing new parameters + +Both lookup calls omit `evict_strategy`, `admit_strategy`, and `admission_counter`: + +```python +# KeyValueTableCachingFunction.lookup (local, wrong): +KeyValueTableCachingFunction.lookup( + caches[i], storages[i], unique_indices_per_table, unique_embs_per_table, + initializers[i], enable_prefetch, training, + lfu_accumulated_frequency_per_table, # lands in evict_strategy slot → type error +) + +# KeyValueTableFunction.lookup (local, wrong): +KeyValueTableFunction.lookup( + storages[i], unique_indices_per_table, unique_embs_per_table, + initializers[i], training, + lfu_accumulated_frequency_per_table, # lands in evict_strategy slot → type error +) +``` + +Both functions now require `evict_strategy: EvictStrategy` as a positional +argument (added by the local #236), so passing `lfu_accumulated_frequency_per_table` +(a tensor or None) in that slot is a type error at runtime. + +**Fix:** Insert `EvictStrategy(evict_strategy.value) if evict_strategy else None` +before `lfu_accumulated_frequency_per_table`, and append `admit_strategy` and +`admission_counter[i] if admission_counter else None` at the end of each call. + +### Bug 4 — `DynamicEmbeddingFunctionV2.backward`: wrong `None` return count + +```python +# Local: +return (None,) * 14 # reflects old parameter count + +# Should be: +return (None,) * 17 # 3 new forward params (admit_strategy, evict_strategy, + # admission_counter) require 3 more None gradients +``` + +PyTorch autograd requires `backward` to return one gradient per `forward` input. +With 3 extra params the count must increase to 17; returning 14 will cause a +shape mismatch assertion in autograd. + +**Fix:** Change `return (None,) * 14` → `return (None,) * 17`. + +### Bug 5 — `DynamicEmbeddingFunctionV2.backward`: update routed to wrong function + +```python +# Local (wrong): always calls KeyValueTableFunction.update with (cache, storage, ...) +KeyValueTableFunction.update( + caches[i], # lands in storage slot → type error + storages[i], # lands in unique_keys slot → type error + unique_indices_per_table, + unique_embs_per_table, + optimizer, + enable_prefetch, # KeyValueTableFunction.update has no such param +) + +# Upstream (correct): branch on caching flag +if caching: + KeyValueTableCachingFunction.update(caches[i], storages[i], ...) +else: + KeyValueTableFunction.update(storages[i], ...) +``` + +`KeyValueTableFunction.update` takes `(storage, unique_keys, ...)` — no `cache` +parameter. The local backward unconditionally called it as +`KeyValueTableFunction.update(caches[i], storages[i], ...)`, so `caches[i]` +landed in `storage` and `storages[i]` in `unique_keys`, causing a type error on +the first iteration. + +**Fix:** Add `caching = caches[0] is not None` at the top of backward, then +dispatch to `KeyValueTableCachingFunction.update` or `KeyValueTableFunction.update` +accordingly. + +### Bug 6 — `KeyValueTableFunction.lookup`: extra unused `cache` parameter shifts all args + +```python +# Local (wrong): +def lookup( + cache: Optional[Cache], # ← never used in body; shifts everything + storage: Storage, + unique_keys, unique_embs, initializer, training, + evict_strategy, accumulated_frequency=None, + admit_strategy=None, admission_counter=None, +) + +# Upstream (correct): +def lookup( + storage: Storage, + unique_keys, unique_embs, initializer, training, + evict_strategy, accumulated_frequency=None, + admit_strategy=None, admission_counter=None, +) +``` + +The call site (`DynamicEmbeddingFunctionV2.forward`) passes `storages[i]` as the +first argument, which lands in `cache` instead of `storage`, causing a type error +when the body calls `storage.embedding_dim()` on a `KeyValueTable` object that +ended up in `cache`. + +**Fix:** Remove the `cache: Optional[Cache]` parameter from `KeyValueTableFunction.lookup`. + +### Known Issue — `setup.py`: missing `import sys` + +`sys.executable` is referenced at lines 29 and 65 but `import sys` is absent. +This causes `NameError: name 'sys' is not defined` at build time. +This is not a cherry-pick bug (the upstream also lacks it); must be patched manually +before each build. Already documented in `merge.md`. + +--- + +## Fixes Applied (commit 268cf1d) + +| Bug | File | Change | +|-----|------|--------| +| 1 — `forward` missing params | `batched_dynamicemb_function.py` | Added `admit_strategy`, `evict_strategy`, `admission_counter` at positions 13–16; shifted `frequency_counters` to 16 | +| 2 — `segmented_unique` bool arg | `batched_dynamicemb_function.py` | `EvictStrategy(evict_strategy.value) if evict_strategy else None` | +| 3 — lookup calls missing params | `batched_dynamicemb_function.py` | Added `evict_strategy`, `admit_strategy`, `admission_counter[i]` to both caching and non-caching lookup calls | +| 4 — `backward` wrong return count | `batched_dynamicemb_function.py` | `(None,) * 14` → `(None,) * 17` | +| 5 — `backward` wrong update function | `batched_dynamicemb_function.py` | Added `caching` flag; dispatch to `KeyValueTableCachingFunction.update` or `KeyValueTableFunction.update` accordingly | +| 6 — extra `cache` param in `lookup` | `key_value_table.py` | Removed `cache: Optional[Cache]` from `KeyValueTableFunction.lookup` | +| Known — `import sys` missing | `setup.py` | Added `import sys` | diff --git a/corelib/dynamicemb/DynamicEmb_APIs.md b/corelib/dynamicemb/DynamicEmb_APIs.md index 8f15035af..e558f32e6 100644 --- a/corelib/dynamicemb/DynamicEmb_APIs.md +++ b/corelib/dynamicemb/DynamicEmb_APIs.md @@ -11,6 +11,11 @@ - [DynamicEmbTableOptions](#dynamicembtableoptions) - [DynamicEmbDump](#dynamicembdump) - [DynamicEmbLoad](#dynamicembload) +- [incremental_dump](#incremental_dump) +- [get_score](#get_score) +- [set_score](#set_score) +- [Counter](#counter) +- [AdmissionStrategy](#admisssion_strategy) ## DynamicEmbParameterConstraints @@ -379,6 +384,25 @@ Dynamic embedding table parameter class, used to configure the parameters for ea safe_check_mode : DynamicEmbCheckMode Should dynamic embedding table insert safe check be enabled? By default, it is disabled. Please refer to the API documentation for DynamicEmbCheckMode for more information. + global_hbm_for_values : int + Total GPU memory allocated to store embedding + optimizer states, in bytes. Default is 0. + It has different meanings under `caching=True` and `caching=False`. + When `caching=False`, it decides how much GPU memory is in the total memory to store value in a single hybrid table. + When `caching=True`, it decides the table capacity of the GPU table. + external_storage: Storage + The external storage/ParamterServer which inherits the interface of Storage, and can be configured per table. + If not provided, will using KeyValueTable as the Storage. + index_type : Optional[torch.dtype], optional + Index type of sparse features, will be set to DEFAULT_INDEX_TYPE(torch.int64) by default. + admit_strategy : Optional[AdmissionStrategy], optional + Admission strategy for controlling which keys are allowed to enter the embedding table. + If provided, only keys that meet the strategy's criteria will be inserted into the table. + Keys that don't meet the criteria will still be initialized and used in the forward pass, + but won't be stored in the table. Default is None (all keys are admitted). + admission_counter : Optional[Counter], optional + Counter for tracking the number of keys that have been admitted to the embedding table. + If provided, the counter will be used to track the number of keys that have been admitted to the embedding table. + Default is None (no counter is used). Notes ----- @@ -390,6 +414,13 @@ Dynamic embedding table parameter class, used to configure the parameters for ea default_factory=DynamicEmbInitializerArgs ) score_strategy: DynamicEmbScoreStrategy = DynamicEmbScoreStrategy.TIMESTAMP + bucket_capacity: int = 128 + safe_check_mode: DynamicEmbCheckMode = DynamicEmbCheckMode.IGNORE + global_hbm_for_values: int = 0 # in bytes + external_storage: Storage = None + index_type: Optional[torch.dtype] = None + admit_strategy: Optional[AdmissionStrategy] = None + admission_counter: Optional[Counter] = None ``` If using `DynamicEmbInitializerMode.UNIFORM`, `DynamicEmbeddingShardingPlanner` will set the `initializer_args.upper` and `initializer_args.lower` to +/- sqrt(1 / eb_config.num_embeddings) @@ -579,4 +610,171 @@ Setting the environment variable DYNAMICEMB_CSTM_SCORE_CHECK to 0 will not throw Returns: None. """ - ``` \ No newline at end of file + ``` + +## Counter + +**dynamicemb** provides an interface to the Counter which will be used in the embedding admission, and the users can customize the counter implementation by inherit the class `Counter`. + + +```python +class Counter(abc.ABC): + """ + Interface of a counter table which maps a key to a counter. + """ + + @abc.abstractmethod + def add( + self, keys: torch.Tensor, frequencies: torch.Tensor, inplace: bool + ) -> torch.Tensor: + """ + Add keys with frequencies to the `Counter` and get accumulated counter of each key. + For not existed keys, the frequencies will be assigned directly. + For existing keys, the frequencies will be accumulated. + Args: + keys (torch.Tensor): The input keys, should be unique keys. + frequencies (torch.Tensor): The input frequencies, serve as initial or incremental values of frequencies' states. + inplace: If true then store the accumulated_frequencies to counter. + Returns: + accumulated_frequencies (torch.Tensor): the frequencies' state in the `Counter` for the input keys. + """ + accumulated_frequencies: torch.Tensor + return accumulated_frequencies + + @abc.abstractmethod + def erase(self, keys) -> None: + """ + Erase keys form the `Counter`. + Args: + keys (torch.Tensor): The input keys to be erased. + """ + + @abc.abstractmethod + def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: + """ + Get the consumption of a specific memory type. + Args: + mem_type (MemoryType): the specific memory type, default to MemoryType.DEVICE. + """ + + @abc.abstractmethod + def load(self, key_file, counter_file) -> None: + """ + Load keys and frequencies from input file path. + Args: + key_file (str): the file path of keys. + counter_file (str): the file path of frequencies. + """ + + @abc.abstractmethod + def dump(self, key_file, counter_file) -> None: + """ + Dump keys and frequencies to output file path. + Args: + key_file (str): the file path of keys. + counter_file (str): the file path of frequencies. + """ +``` + +**dynamicemb** also provides a built-in counter implementation named `KVCounter`. +There is as capacity limit of `KVCounter` which is bucketized, and the key with the smallest frequency will be evicted from the bucket for a new key if the bucket is full. + +```python + +class KVCounter(Counter): + """ + Interface of a counter table which maps a key to a counter. + """ + + def __init__( + self, + capacity: int, + bucket_capacity: Optional[int] = 128, + key_type: Optional[torch.dtype] = torch.int64, + device: torch.device = None, + ) +``` + +## AdmissionStrategy + +**AdmissionStrategy** is another component for implementing embedding admission. +The keys not in the dynamic embedding table, will first be passed to the `Counter`, after get the accumulated frequencies among the previous training process, the `AdmissionStrategy` will determine which keys will be admitted into the dynamic embedding table. + +```python +class AdmissionStrategy(abc.ABC): + @abc.abstractmethod + def admit( + self, + keys: torch.Tensor, + frequencies: torch.Tensor, + ) -> torch.Tensor: + """ + Admit keys with frequencies >= threshold. + """ + + @abc.abstractmethod + def get_initializer_args(self) -> Optional[DynamicEmbInitializerArgs]: + """ + Get the initializer args for keys that are not admitted. + """ +``` + +**dynamicemn** provides built-in `FrequencyAdmissionStrategy`, which will return keys whose frequencies are not less than the threshold. + +```python +class FrequencyAdmissionStrategy(AdmissionStrategy): + """ + Frequency-based admission strategy. + Only admits keys whose frequency (score) meets or exceeds a threshold. + Parameters + ---------- + threshold : int + Minimum frequency threshold for admission. Keys with frequency >= threshold + will be admitted into the embedding table. + initializer_args: Optional[DynamicEmbInitializerArgs] + Initializer arguments which determine how to initialize the embedding if the key is not admitted. + """ + + def __init__( + self, + threshold: int, + initializer_args: Optional[DynamicEmbInitializerArgs] = None, + ) +``` + +# Functionality and User interface + +## Distributed embedding training and evaluation + +Once the model containing `EmbeddingCollection` is built and initialized through `DistributedModelParallel`, it can be trained and evaluated on each GPU like a single GPU, with torchrec completing communication between different GPUs. + +The switching between training and evaluation modes should be consistent with `nn.Module`, while `training` in [DynamicEmbTableOptions](../dynamicemb/dynamicemb_config.py) is used to guide whether to allocate memory to optimizer states when builds the table. + +Due to limited resources, the dynamic embedding table does not pre allocate memory for all keys. If a key appears for the first time during training, it will be initialized immediately during the training process. Please see `initializer_args` and `eval_initializer_args` in `DynamicEmbTableOptions` for more information. + +## Automatic eviction + +The size of the table is finite, but the set of keys during training may be infinite. dynamicemb provides the function of automatic eviction, which constrains the size of tables reasonably when there is no available space. See `score_strategy` and `bucket_capacity` for more information. + +## Caching and prefetch + +dynamicemb supports caching hot embeddings on GPU memory, and you can prefetch keys from host to device like torchrec(document and example is waiting to append, and now please see `test_prefetch_flush_in_cache` in [test prefetch](./test/test_batched_dynamic_embedding_tables_v2.py)). + +## External storage + +dynamicemb supports external storage once `external_storage` in `DynamicEmbTableOptions` inherits the `Storage` interface under [types.py](../dynamicemb/types.py). +Refer to demo `PyDictStorage` in [uint test](../test/test_batched_dynamic_embedding_tables_v2.py) for detailed usage. + + +## Table expansion + +Users can specify the initial capacity of a table on a single GPU. When the specified load factor is reached, the capacity of the table will double until the limit is reached. See `init_capacity`, `max_load_factor`, `max_capacity` in `DynamicEmbTableOptions` for more information. + + +## Dump/Load and Incremental dump + +Dump/Load and incremental dump is different from general module in PyTorch, because dynamicemb's underlying implementation is a hash table instead of a dense `torch.Tensor`. + +So dynamicemb provides dedicated interface to load/save models' states, and provide conditional dump to support online training. + +Please see `DynamicEmbDump`, `DynamicEmbLoad`, `incremental_dump` in [APIs Doc](../DynamicEmb_APIs.md) for more information. diff --git a/corelib/dynamicemb/STYLE_GUIDE.md b/corelib/dynamicemb/STYLE_GUIDE.md index 7096eb56f..f2b0aef43 100644 --- a/corelib/dynamicemb/STYLE_GUIDE.md +++ b/corelib/dynamicemb/STYLE_GUIDE.md @@ -22,7 +22,7 @@ sudo apt install clang-format-18 format all with: ```bash -find ./ \( -path ./HierarchicalKV -prune \) -o \( -iname *.h -o -iname *.cpp -o -iname *.cc -o -iname *.cu -o -iname *.cuh \) -print | xargs clang-format-18 -i --style=file +find ./src -type f \( -name "*.cu" -o -name "*.cuh" -o -name "*.cpp" -o -name "*.h" \) -exec clang-format-18 -i {} \; ``` diff --git a/corelib/dynamicemb/benchmark/README.md b/corelib/dynamicemb/benchmark/README.md index 8cdf7de33..19123cdfd 100644 --- a/corelib/dynamicemb/benchmark/README.md +++ b/corelib/dynamicemb/benchmark/README.md @@ -59,13 +59,13 @@ When generating indices, we utilize an extremely large range(2^63), so that most The overhead(ms) on H100 PCIe: -| use_index_dedup | batch_size | num_embeddings_per_feature | hbm_for_embeddings | optimizer_type | forward_overhead | backward_overhead | totoal_overhead | -|-----------------|------------|----------------------------|--------------------|----------------|------------------|-------------------|-----------------| -| TRUE | 65536 | 8388608 | 4 | sgd | 0.54184 | 0.363057 | 0.904897 | -| TRUE | 65536 | 8388608 | 4 | adam | 0.601176 | 0.477679 | 1.078855 | -| TRUE | 65536 | 67108864 | 4 | sgd | 2.746669 | 4.148325 | 6.894995 | -| TRUE | 65536 | 67108864 | 4 | adam | 3.226324 | 11.76063 | 14.98695 | -| TRUE | 1048576 | 8388608 | 4 | sgd | 5.158324 | 3.05149 | 8.209814 | -| TRUE | 1048576 | 8388608 | 4 | adam | 5.170962 | 7.844773 | 13.01574 | -| TRUE | 1048576 | 67108864 | 4 | sgd | 50.48192 | 56.61244 | 107.0944 | -| TRUE | 1048576 | 67108864 | 4 | adam | 74.15156 | 186.0786 | 260.2301 | \ No newline at end of file +| use_index_dedup | batch_size | num_embeddings_per_feature | hbm_for_embeddings | optimizer_type | feature_distribution-alpha | embedding_dim | num_iterations | cache_algorithm | eval(torchrec) | forward(torchrec) | backward(torchrec) | train(torchrec) | eval(dynamicemb) | forward(dynamicemb) | backward(dynamicemb) | train(dynamicemb) | +| --------------- | ---------- | -------------------------- | ------------------ | -------------- | -------------------------- | ------------- | -------------- | --------------- | -------------- | ----------------- | ------------------ | --------------- | ---------------- | ------------------- | -------------------- | ----------------- | +| False | 65536 | 8388608 | 4294967296 | sgd | pow-law-1.05 | 128 | 100 | lru | 0.4965 | 0.4972 | 0.4929 | 0.9901 | 0.0687 | 0.1951 | 0.4059 | 0.5999 | +| False | 65536 | 8388608 | 12884901888 | adam | pow-law-1.05 | 128 | 100 | lru | 0.5000 | 0.4999 | 1.1617 | 1.6616 | 0.0691 | 0.2001 | 0.4339 | 0.6347 | +| False | 65536 | 67108864 | 4294967296 | sgd | pow-law-1.05 | 128 | 100 | lru | 0.5124 | 0.5124 | 0.5376 | 1.0499 | 1.0508 | 1.1495 | 1.282 | 2.4302 | +| False | 65536 | 67108864 | 12884901888 | adam | pow-law-1.05 | 128 | 100 | lru | 0.5158 | 0.5157 | 1.2876 | 1.8033 | 1.0916 | 1.2015 | 1.4509 | 2.6543 | +| False | 1048576 | 8388608 | 4294967296 | sgd | pow-law-1.05 | 128 | 100 | lru | 7.5263 | 7.5274 | 3.6960 | 11.2234 | 0.6011 | 0.8402 | 1.6120 | 2.4558 | +| False | 1048576 | 8388608 | 12884901888 | adam | pow-law-1.05 | 128 | 100 | lru | 7.5300 | 7.5305 | 10.2640 | 17.7945 | 0.6012 | 0.8596 | 1.8197 | 2.6794 | +| False | 1048576 | 67108864 | 4294967296 | sgd | pow-law-1.05 | 128 | 100 | lru | 7.8093 | 7.8095 | 4.4519 | 12.2614 | 15.0906 | 10.8440 | 11.8741 | 22.7194 | +| False | 1048576 | 67108864 | 12884901888 | adam | pow-law-1.05 | 128 | 100 | lru | 7.8124 | 7.8129 | 12.5192 | 20.3321 | 15.5863 | 11.2428 | 12.6806 | 23.9257 | \ No newline at end of file diff --git a/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py b/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py index 7d13baa03..626a47cd0 100644 --- a/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py +++ b/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.py @@ -16,10 +16,7 @@ import argparse import json import os -import random -import sys -import time -from typing import List +from typing import cast import torch import torch.distributed as dist @@ -31,9 +28,561 @@ DynamicEmbTableOptions, EmbOptimType, ) -from dynamicemb.batched_dynamicemb_tables import BatchedDynamicEmbeddingTables +from dynamicemb.batched_dynamicemb_tables import ( + BatchedDynamicEmbeddingTables, + BatchedDynamicEmbeddingTablesV2, +) +from dynamicemb.key_value_table import KeyValueTable +from dynamicemb_extensions import DynamicEmbTable, insert_or_assign +from fbgemm_gpu.runtime_monitor import StdLogStatsReporterConfig +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BoundsCheckMode, + CacheAlgorithm, + EmbeddingLocation, + PoolingMode, + RecordCacheMetrics, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + ComputeDevice, + SplitTableBatchedEmbeddingBagsCodegen, +) from torch.distributed.elastic.multiprocessing.errors import record +report_interval = 10 +warmup_repeat = 100 + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def get_emb_precision(precision_str): + if precision_str == "fp32": + return torch.float32 + elif precision_str == "fp16": + return torch.float16 + elif precision_str == "bf16": + return torch.bfloat16 + else: + raise ValueError("unknown embedding precision type") + + +def get_fbgemm_precision(precision_str): + if precision_str == "fp32": + return SparseType.FP32 + elif precision_str == "fp16": + return SparseType.FP16 + elif precision_str == "bf16": + return SparseType.BF16 + else: + raise ValueError("unknown embedding precision type") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark BatchedDynamicEmbeddingTables in dynamicemb." + ) + + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="batch size used for training", + ) + parser.add_argument( + "--num_embeddings_per_feature", + type=str, + default="1", + help="Comma separated max_ind_size(MB) per sparse feature. The number of embeddings in each embedding table.", + ) + parser.add_argument( + "--num_iterations", + type=int, + default=100, + help="number of iterations", + ) + parser.add_argument( + "--hbm_for_embeddings", + type=str, + default="1", + help="HBM reserved for values in GB.", + ) + parser.add_argument( + "--optimizer_type", + type=str, + default="adam", + choices=["sgd", "adam", "exact_adagrad", "exact_row_wise_adagrad"], + help="optimizer type.", + ) + parser.add_argument( + "--feature_distribution", + type=str, + default="random", + choices=["random", "pow-law"], + help="Distribution of sparse features.", + ) + parser.add_argument( + "--alpha", type=float, default=1.05, help="Exponent of power-law distribution." + ) + parser.add_argument( + "--seed", type=int, default=42, help="random seed used for initialization" + ) + parser.add_argument( + "--use_index_dedup", + action="store_true", + help="Use index deduplication, using to select the codepath.", + ) + parser.add_argument("--caching", action="store_true") + parser.add_argument("--cache_metrics", action="store_true") + parser.add_argument( + "--embedding_dim", type=int, default=128, help="Size of each embedding." + ) + parser.add_argument( + "--emb_precision", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16", "fp8"], + ) + parser.add_argument( + "--output_dtype", + type=str, + default="fp32", + choices=["fp32", "fp16", "bf16", "fp8"], + ) + parser.add_argument( + "--cache_algorithm", + type=str, + default="lru", + choices=["lru", "lfu"], + ) + parser.add_argument( + "--gpu_ratio", + type=float, + default=0.125, + help="cache how many embeddings to HBM", + ) + parser.add_argument( + "--table_version", + type=int, + default=1, + help="Table Version", + ) + + parser.add_argument("--learning_rate", type=float, default=0.1) + parser.add_argument("--eps", type=float, default=1e-3, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=0.9, help="beta1.") + parser.add_argument("--beta2", type=float, default=0.999, help="beta1.") + parser.add_argument("--weight_decay", type=float, default=0, help="weight_decay.") + + args = parser.parse_args() + args.num_embeddings_per_feature = [ + int(v) * 1024 * 1024 for v in args.num_embeddings_per_feature.split(",") + ] + args.num_embedding_table = len(args.num_embeddings_per_feature) + args.hbm_for_embeddings = [ + int(v) * (1024**3) for v in args.hbm_for_embeddings.split(",") + ] + + return args + + +def table_idx_to_name(i): + return f"t_{i}" + + +def feature_idx_to_name(i): + return f"cate_{i}" + + +def get_dynamicemb_optimizer(optimizer_type): + if optimizer_type == "sgd": + return EmbOptimType.EXACT_SGD + elif optimizer_type == "exact_sgd": + return EmbOptimType.EXACT_SGD + elif optimizer_type == "adam": + return EmbOptimType.ADAM + elif optimizer_type == "exact_adagrad": + return EmbOptimType.EXACT_ADAGRAD + elif optimizer_type == "exact_row_wise_adagrad": + return EmbOptimType.EXACT_ROWWISE_ADAGRAD + else: + raise ValueError("unknown optimizer type") + + +def get_fbgemm_optimizer(optimizer_type): + if optimizer_type == "sgd": + return OptimType.EXACT_SGD + elif optimizer_type == "exact_sgd": + return OptimType.EXACT_SGD + elif optimizer_type == "adam": + return OptimType.ADAM + elif optimizer_type == "exact_adagrad": + return OptimType.EXACT_ADAGRAD + elif optimizer_type == "exact_row_wise_adagrad": + return OptimType.EXACT_ROWWISE_ADAGRAD + else: + raise ValueError("unknown optimizer type") + + +def generate_sequence_sparse_feature(args, device): + feature_names = [ + feature_idx_to_name(feature_idx) + for feature_idx in range(args.num_embedding_table) + ] + if args.feature_distribution == "random": + res = [] + for x in range(args.num_iterations): + indices_list = [] + lengths_list = [] + for i in range(args.num_embedding_table): + indices_list.append( + torch.randint(low=0, high=(2**63) - 1, size=(args.batch_size,)) + ) + indices = torch.cat(indices_list, dim=0) + indices = indices.to(dtype=torch.int64, device="cuda") + lengths_list.extend([1] * args.batch_size * args.num_embedding_table) + lengths = torch.tensor(lengths_list, dtype=torch.int64).cuda() + + res.append( + torchrec.KeyedJaggedTensor( + keys=feature_names, + values=indices, + lengths=lengths, + ) + ) + return res + elif args.feature_distribution == "pow-law": + assert args.num_embedding_table == 1 + from dataset_generator import gen_jagged_key + + res = [ + gen_jagged_key( + args.batch_size, + 1, + args.alpha, + args.num_embeddings_per_feature[0], + device, + feature_names, + ) + for i in range(args.num_iterations) + ] + return res + elif args.feature_distribution == "zipf": + assert args.num_embedding_table == 1 + from dataset_generator import zipf + + total_indices = zipf( + min_val=0, + max_val=args.num_embeddings_per_feature[0], + exponent=args.alpha, + size=args.batch_size * args.num_iterations, + device=device, + ) + total_indices = total_indices.to(dtype=torch.int64, device="cuda") + res = [] + for x in range(args.num_iterations): + indices = total_indices[x * args.batch_size : (x + 1) * args.batch_size] + lengths_list = [] + lengths_list.extend([1] * args.batch_size * args.num_embedding_table) + lengths = torch.tensor(lengths_list, dtype=torch.int64).cuda() + feature_names = [ + feature_idx_to_name(feature_idx) + for feature_idx in range(args.num_embedding_table) + ] + + res.append( + torchrec.KeyedJaggedTensor( + keys=feature_names, + values=indices, + lengths=lengths, + ) + ) + return res + else: + raise ValueError( + f"Not support distribution {args.feature_distribution} of sparse features." + ) + + +class TableShim: + def __init__(self, table): + if isinstance(table, DynamicEmbTable): + self.table = cast(DynamicEmbTable, table) + elif isinstance(table, KeyValueTable): + self.table = cast(KeyValueTable, table) + else: + raise ValueError("Not support table type") + + def optim_states_dim(self) -> int: + if isinstance(self.table, DynamicEmbTable): + return self.table.optstate_dim() + else: + return self.table.value_dim() - self.table.embedding_dim() + + def init_optim_state(self) -> float: + if isinstance(self.table, DynamicEmbTable): + return self.table.get_initial_optstate() + else: + return self.table.init_optimizer_state() + + def insert( + self, + n, + unique_indices, + unique_values, + scores, + ) -> None: + if isinstance(self.table, DynamicEmbTable): + insert_or_assign(self.table, n, unique_indices, unique_values, scores) + else: + # self.table.set_score(scores[0].item()) + self.table.insert(unique_indices, unique_values, scores) + + +def create_dynamic_embedding_tables(args, device): + table_options = [] + table_num = args.num_embedding_table + for i in range(table_num): + if args.table_version == 1: + TableModule = BatchedDynamicEmbeddingTables + table_options.append( + DynamicEmbTableOptions( + index_type=torch.int64, + embedding_dtype=get_emb_precision(args.emb_precision), + dim=args.embedding_dim, + max_capacity=args.num_embeddings_per_feature[i], + local_hbm_for_values=args.hbm_for_embeddings[i], + bucket_capacity=128, + initializer_args=DynamicEmbInitializerArgs( + mode=DynamicEmbInitializerMode.NORMAL, + ), + score_strategy=DynamicEmbScoreStrategy.LFU + if args.cache_algorithm == "lfu" + else DynamicEmbScoreStrategy.TIMESTAMP, + ) + ) + elif args.table_version == 2: + TableModule = BatchedDynamicEmbeddingTablesV2 + table_options.append( + DynamicEmbTableOptions( + index_type=torch.int64, + embedding_dtype=get_emb_precision(args.emb_precision), + dim=args.embedding_dim, + max_capacity=args.num_embeddings_per_feature[i], + local_hbm_for_values=args.hbm_for_embeddings[i], + bucket_capacity=128, + initializer_args=DynamicEmbInitializerArgs( + mode=DynamicEmbInitializerMode.NORMAL, + ), + score_strategy=DynamicEmbScoreStrategy.LFU + if args.cache_algorithm == "lfu" + else DynamicEmbScoreStrategy.TIMESTAMP, + caching=args.caching, + ) + ) + else: + raise ValueError("Not support table version") + + var = TableModule( + table_options=table_options, + table_names=[table_idx_to_name(i) for i in range(table_num)], + use_index_dedup=args.use_index_dedup, + pooling_mode=DynamicEmbPoolingMode.NONE, + output_dtype=get_emb_precision(args.output_dtype), + device=device, + optimizer=get_dynamicemb_optimizer(args.optimizer_type), + learning_rate=args.learning_rate, + eps=args.eps, + weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + ) + + for table_id in range(table_num): + cur_table = TableShim(var.tables[table_id]) + + num_embeddings = args.num_embeddings_per_feature[table_id] + fill_batch = 1024 * 1024 + i = 0 + while i < num_embeddings: + start = i + end = min(i + fill_batch, num_embeddings) + i += fill_batch + unique_indices = torch.arange(start, end, device=device, dtype=torch.int64) + unique_values = torch.rand( + unique_indices.numel(), + args.embedding_dim, + device=device, + dtype=torch.float32, + ) + + optstate_dim = cur_table.optim_states_dim() + initial_accumulator = cur_table.init_optim_state() + optstate = ( + torch.rand( + unique_values.size(0), + optstate_dim, + dtype=unique_values.dtype, + device=unique_values.device, + ) + * initial_accumulator + ) + unique_values = torch.cat((unique_values, optstate), dim=1).contiguous() + unique_values = unique_values.reshape(-1) + + n = unique_indices.shape[0] + scores = ( + torch.ones(n, dtype=torch.uint64, device=unique_indices.device) + if args.cache_algorithm == "lfu" + else None + ) + cur_table.insert(n, unique_indices, unique_values, scores) + + return var + + +def create_split_table_batched_embeddings(args, device): + optimizer = get_fbgemm_optimizer(args.optimizer_type) + D = args.embedding_dim + Es = args.num_embeddings_per_feature + cache_alg = ( + CacheAlgorithm.LRU if args.cache_algorithm == "lru" else CacheAlgorithm.LFU + ) + + if args.caching: + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + e, + D, + EmbeddingLocation.MANAGED_CACHING, + ComputeDevice.CUDA, + ) + for e in Es + ], + optimizer=optimizer, + weights_precision=get_fbgemm_precision(args.emb_precision), + stochastic_rounding=False, + cache_load_factor=args.gpu_ratio, + cache_algorithm=cache_alg, + pooling_mode=PoolingMode.NONE, + output_dtype=get_fbgemm_precision(args.output_dtype), + device=device, + learning_rate=args.learning_rate, + eps=args.eps, + weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + bounds_check_mode=BoundsCheckMode.NONE, + stats_reporter_config=StdLogStatsReporterConfig(report_interval), + record_cache_metrics=RecordCacheMetrics(True, False), + ).cuda() + else: + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + e, + D, + EmbeddingLocation.MANAGED, + ComputeDevice.CUDA, + ) + for e in Es + ], + optimizer=optimizer, + weights_precision=get_fbgemm_precision(args.emb_precision), + stochastic_rounding=False, + pooling_mode=PoolingMode.NONE, + output_dtype=get_fbgemm_precision(args.output_dtype), + device=device, + learning_rate=args.learning_rate, + eps=args.eps, + weight_decay=args.weight_decay, + beta1=args.beta1, + beta2=args.beta2, + bounds_check_mode=BoundsCheckMode.NONE, + ).cuda() + return emb + + +def warmup_gpu(device="cuda"): + # 1. compute unit + a = torch.randn(10, 16384, 2048, device=device) + b = torch.randn(10, 2048, 16384, device=device) + for _ in range(5): + torch.matmul(a, b) + torch.cuda.synchronize() + + # 2. copy engine + d_cpu = torch.randn(10, 1024, 1024) + d_gpu = torch.empty_like(d_cpu, device=device) + for _ in range(5): + # CPU -> GPU + d_gpu.copy_(d_cpu, non_blocking=True) + torch.cuda.synchronize() + # GPU -> CPU + d_cpu.copy_(d_gpu, non_blocking=True) + torch.cuda.synchronize() + + +def benchmark_one_iteration(model, sparse_feature): + start_event = torch.cuda.Event(enable_timing=True) + mid_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + start_event.record() + output = model(sparse_feature.values(), sparse_feature.offsets()) + mid_event.record() + grad = torch.empty_like(output) + output.backward(grad) + end_event.record() + + torch.cuda.synchronize() + forward_latency = start_event.elapsed_time(mid_event) + backward_latency = mid_event.elapsed_time(end_event) + iteration_latency = start_event.elapsed_time(end_event) + return forward_latency, backward_latency, iteration_latency + + +def benchmark_train_eval(model, sparse_features, timer, args): + model.train() + + timer.start() + for i in range(args.num_iterations): + sparse_feature = sparse_features[i] + output = model(sparse_feature.values(), sparse_feature.offsets()) + grad = torch.empty_like(output) + output.backward(grad) + timer.stop() + train_latency = timer.elapsed_time() / args.num_iterations + + timer.start() + for i in range(args.num_iterations): + sparse_feature = sparse_features[i] + output = model(sparse_feature.values(), sparse_feature.offsets()) + timer.stop() + train_forward_latency = timer.elapsed_time() / args.num_iterations + + train_backward_latency = train_latency - train_forward_latency + + model.eval() + timer.start() + for i in range(args.num_iterations): + sparse_feature = sparse_features[i] + output = model(sparse_feature.values(), sparse_feature.offsets()) + timer.stop() + eval_latency = timer.elapsed_time() / args.num_iterations + + return train_latency, train_forward_latency, train_backward_latency, eval_latency + def append_to_json(file_path, data): try: @@ -115,7 +664,19 @@ def count_tensor_to_dict(x, d): d[key] += 1 -def test(args): +def clear_cache(args, dynamic_emb, torchrec_emb): + assert args.caching + dynamic_emb.reset_cache_states() + torchrec_emb.reset_cache_states() + + +@record +def main(): + args = parse_args() + print("Arguments:") + for arg, value in vars(args).items(): + print(f"{arg}: {value}") + backend = "nccl" dist.init_process_group(backend=backend) @@ -157,37 +718,67 @@ def test(args): beta2=args.beta2, ) - num_iterations = args.num_iterations - - warm_iters = 10 - sparse_features = [] - for i in range(num_iterations * 2 + warm_iters): - sparse_features.append( - generate_dynamic_sequence_sparse_feature(args.batch_size) - ) - - for i in range(warm_iters): - sparse_feature = sparse_features[i] - res = var(sparse_feature.values(), sparse_feature.offsets()) - grad = torch.ones_like(res) - res.backward(grad) - - # forward - torch.cuda.synchronize() - start_time = time.perf_counter() - for i in range(num_iterations): - sparse_feature = sparse_features[i + warm_iters] - res = var(sparse_feature.values(), sparse_feature.offsets()) - - torch.cuda.synchronize() - end_time = time.perf_counter() - average_iteration_time_fw = (end_time - start_time) / args.num_iterations * 1000 - print(f"Total time taken: {end_time - start_time:.4f} seconds") - print(f"Average time per iteration(forward): {average_iteration_time_fw:.4f} ms") + torchrec_emb = create_split_table_batched_embeddings(args, device) + cache_miss_counter_torchrec = None + + if args.caching: + var.set_record_cache_metrics(True) + clear_cache(args, var, torchrec_emb) + + warmup_gpu(device) + for i in range(0, args.num_iterations, report_interval): + for j in range(report_interval): + ( + forward_latency, + backward_latency, + iteration_latency, + ) = benchmark_one_iteration(var, sparse_features[i + j]) + cache_info = "" + if args.caching: + cache_metrics = var.caches[0].cache_metrics + unique_num = cache_metrics[0].item() + cache_hit = cache_metrics[1].item() + cache_miss = unique_num - cache_hit + hit_rate = 1.0 * cache_hit / unique_num + cache_info = f"cache_miss:{cache_miss}, unique: {unique_num}, hit_rate: {hit_rate:.8f}," + print( + f"dynamicemb: Iteration {i + j}, forward: {forward_latency:.3f} ms, backward: {backward_latency:.3f} ms, " + f"total: {iteration_latency:.3f} ms, cache info: {cache_info}" + ) + + for j in range(report_interval): + ( + forward_latency, + backward_latency, + iteration_latency, + ) = benchmark_one_iteration(torchrec_emb, sparse_features[i + j]) + cache_info = "" + if args.caching: + cache_miss_counter_ = torchrec_emb.get_cache_miss_counter().clone() + # table_wise_cache_miss_ = torchrec_emb.get_table_wise_cache_miss().clone() + if cache_miss_counter_torchrec is not None: + cache_miss_counter_incerment = ( + cache_miss_counter_ - cache_miss_counter_torchrec + ) + else: + cache_miss_counter_incerment = torch.tensor([0, 0]) + # if table_wise_cache_miss is not None: + # table_wise_cache_miss_increment = table_wise_cache_miss_ - table_wise_cache_miss + # else: + # table_wise_cache_miss_increment = torch.tensor([0]) + cache_info = f"cache miss: {cache_miss_counter_incerment[1].item()}" + cache_miss_counter_torchrec = cache_miss_counter_ + + print( + f"torchrec: Iteration {i + j}, forward: {forward_latency:.3f} ms, backward: {backward_latency:.3f} ms, " + f"total: {iteration_latency:.3f} ms, cache info: {cache_info}" + ) + + if args.caching: + var.set_record_cache_metrics(False) + torchrec_emb.record_cache_metrics = RecordCacheMetrics(False, False) + clear_cache(args, var, torchrec_emb) - # forward + backward - torch.cuda.synchronize() - start_time = time.perf_counter() torch.cuda.profiler.start() for i in range(num_iterations): sparse_feature = sparse_features[i + warm_iters + num_iterations] @@ -205,14 +796,25 @@ def test(args): ) test_result = { - "use_index_dedup": args.use_index_dedup, + "caching": args.caching, + "table_version": args.table_version, "batch_size": args.batch_size, "num_embeddings_per_feature": args.num_embeddings_per_feature, "hbm_for_embeddings": args.hbm_for_embeddings, "optimizer_type": args.optimizer_type, - "forward_overhead": average_iteration_time_fw, - "backward_overhead": average_iteration_time - average_iteration_time_fw, - "totoal_overhead": average_iteration_time, + "feature_distribution-alpha": f"{args.feature_distribution}-{args.alpha}", + "embedding_dim": args.embedding_dim, + "num_iterations": args.num_iterations, + "cache_algorithm": args.cache_algorithm, + "use_index_dedup": args.use_index_dedup, + "eval(torchrec)": torchrec_res[3], + "forward(torchrec)": torchrec_res[1], + "backward(torchrec)": torchrec_res[2], + "train(torchrec)": torchrec_res[0], + "eval(dynamicemb)": dynamicemb_res[3], + "forward(dynamicemb)": dynamicemb_res[1], + "backward(dynamicemb)": dynamicemb_res[2], + "train(dynamicemb)": dynamicemb_res[0], } append_to_json("benchmark_results.json", test_result) diff --git a/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.sh b/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.sh index 90a1ddaeb..e401aa020 100644 --- a/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.sh +++ b/corelib/dynamicemb/benchmark/benchmark_batched_dynamicemb_tables.sh @@ -3,11 +3,13 @@ export CUDA_VISIBLE_DEVICES=0 declare -A hbm=(["sgd"]=4 ["adam"]=12) - -use_index_dedups=("True") +use_index_dedups=("False") batch_sizes=(65536 1048576) capacities=("8" "64") optimizer_types=("sgd" "adam") +embedding_dims=(128) +alphas=(1.05) +gpu_ratio=0.125 rm benchmark_results.json for use_index_dedup in "${use_index_dedups[@]}"; do @@ -17,15 +19,50 @@ for use_index_dedup in "${use_index_dedups[@]}"; do echo "####" $use_index_dedup $batch_size $capacity ${hbm[$optimizer_type]} $optimizer_type - # ncu -f --target-processes all --export dynamicemb-rep.report --section SchedulerStats --section WarpStateStats --import-source=yes --page raw --set full --profile-from-start no -k regex:"fill_output_with_table_vectors_kernel|initialize_optimizer_state_kernel" \ - # nsys profile -s none -t cuda,nvtx,osrt,mpi,ucx -f true -o dynamicemb$batch_size.qdrep -c cudaProfilerApi --cpuctxsw none --cuda-flush-interval 100 --capture-range-end=stop --cuda-graph-trace=node \ - torchrun --nnodes 1 --nproc_per_node 1 \ - ./benchmark/benchmark_batched_dynamicemb_tables.py \ - --use_index_dedup $use_index_dedup \ - --batch_size $batch_size \ - --num_embeddings_per_feature $capacity \ - --hbm_for_embeddings ${hbm[$optimizer_type]} \ - --optimizer_type $optimizer_type + torchrun --nnodes 1 --nproc_per_node 1 \ + ./benchmark/benchmark_batched_dynamicemb_tables.py \ + --caching \ + --cache_algorithm "lru" \ + --gpu_ratio $gpu_ratio \ + --batch_size $batch_size \ + --num_embeddings_per_feature $capacity \ + --embedding_dim $embedding_dim \ + --hbm_for_embeddings ${hbm[$optimizer_type]} \ + --optimizer_type $optimizer_type \ + --feature_distribution "pow-law" \ + --alpha $alpha \ + --num_iterations 100 \ + --table_version 2 + + torchrun --nnodes 1 --nproc_per_node 1 \ + ./benchmark/benchmark_batched_dynamicemb_tables.py \ + --batch_size $batch_size \ + --num_embeddings_per_feature $capacity \ + --embedding_dim $embedding_dim \ + --hbm_for_embeddings ${hbm[$optimizer_type]} \ + --optimizer_type $optimizer_type \ + --feature_distribution "pow-law" \ + --alpha $alpha \ + --num_iterations 100 \ + --cache_algorithm "lru" \ + --table_version 2 + + + # ncu -f --target-processes all --export de_and_tr-$batch_size-$capacity-$optimizer_type-rep.report --section SchedulerStats --section WarpStateStats --import-source=yes --page raw --set full --profile-from-start no -k regex:"load_or_initialize_" \ + # nsys profile -s none -t cuda,nvtx,osrt,mpi,ucx -f true -o de_and_tr-$batch_size-$capacity-$optimizer_type.qdrep -c cudaProfilerApi --cpuctxsw none --cuda-flush-interval 100 --capture-range-end=stop --cuda-graph-trace=node \ + torchrun --nnodes 1 --nproc_per_node 1 \ + ./benchmark/benchmark_batched_dynamicemb_tables.py \ + --batch_size $batch_size \ + --num_embeddings_per_feature $capacity \ + --hbm_for_embeddings ${hbm[$optimizer_type]} \ + --optimizer_type $optimizer_type \ + --feature_distribution "pow-law" \ + --embedding_dim $embedding_dim \ + --num_iterations 100 \ + --cache_algorithm "lru" \ + --alpha $alpha \ + --table_version 1 + done done done done diff --git a/corelib/dynamicemb/dynamicemb/__init__.py b/corelib/dynamicemb/dynamicemb/__init__.py index 18526f4d7..b29e4fd74 100644 --- a/corelib/dynamicemb/dynamicemb/__init__.py +++ b/corelib/dynamicemb/dynamicemb/__init__.py @@ -18,8 +18,6 @@ BATCH_SIZE_PER_DUMP, DynamicEmbCheckMode, DynamicEmbEvictStrategy, - DynamicEmbInitializerArgs, - DynamicEmbInitializerMode, DynamicEmbPoolingMode, DynamicEmbScoreStrategy, DynamicEmbTableOptions, @@ -29,9 +27,20 @@ string_to_evict_strategy, torch_to_dyn_emb, ) +from .embedding_admission import FrequencyAdmissionStrategy, KVCounter from .optimizer import EmbOptimType, OptimizerArgs +from .types import ( + AdmissionStrategy, + Counter, + DynamicEmbInitializerArgs, + DynamicEmbInitializerMode, +) __all__ = [ + "AdmissionStrategy", + "FrequencyAdmissionStrategy", + "Counter", + "KVCounter", "DynamicEmbCheckMode", "DynamicEmbInitializerArgs", "DynamicEmbInitializerMode", diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_compute_kernel.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_compute_kernel.py index f03ff3d86..db2c3f9e6 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_compute_kernel.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_compute_kernel.py @@ -20,8 +20,9 @@ import torch import torch.distributed as dist from dynamicemb.batched_dynamicemb_tables import BatchedDynamicEmbeddingTables -from dynamicemb.dynamicemb_config import DEFAULT_INDEX_TYPE, DynamicEmbPoolingMode +from dynamicemb.dynamicemb_config import DEFAULT_INDEX_TYPE, DynamicEmbPoolingMode, DynamicEmbTableOptions from dynamicemb.optimizer import string_to_opt_type +from dynamicemb import DynamicEmbInitializerArgs from fbgemm_gpu.split_table_batched_embeddings_ops_training import PoolingMode from torch import nn from torchrec.distributed.batched_embedding_kernel import ( @@ -230,6 +231,14 @@ def _clean_grouped_fused_params(fused_params: Dict[str, Any]): fused_params["optimizer"] = dyn_emb_opt_type +def _fix_dynamicemb_options(dynamicemb_options: Any) -> DynamicEmbTableOptions: + if not isinstance(dynamicemb_options, DynamicEmbTableOptions): + dynamicemb_options = DynamicEmbTableOptions(**dynamicemb_options) + if not isinstance(dynamicemb_options.initializer_args, DynamicEmbInitializerArgs): + dynamicemb_options.initializer_args = DynamicEmbInitializerArgs(**dynamicemb_options.initializer_args) + return dynamicemb_options + + class BatchedDynamicEmbeddingBag( BaseBatchedEmbeddingBag[torch.Tensor], # FusedOptimizerModule # BaseBatchedEmbeddingBag[BatchedDynamicEmbeddingTables, torch.Tensor], # FusedOptimizerModule @@ -255,6 +264,7 @@ def __init__( self._local_rows, self._local_cols, config.embedding_tables ): dynamicemb_options = table.fused_params["dynamicemb_options"] + dynamicemb_options = _fix_dynamicemb_options(dynamicemb_options) dynamicemb_options.dim = local_col dynamicemb_options.max_capacity = local_row if dynamicemb_options.index_type is None: @@ -373,6 +383,7 @@ def __init__( self._local_rows, self._local_cols, config.embedding_tables ): dynamicemb_options = table.fused_params["dynamicemb_options"] + dynamicemb_options = _fix_dynamicemb_options(dynamicemb_options) dynamicemb_options.dim = local_col dynamicemb_options.max_capacity = local_row if dynamicemb_options.index_type is None: @@ -465,3 +476,10 @@ def flush(self) -> None: def purge(self) -> None: self._emb_module.reset_cache_states() + + def forward(self, features) -> torch.Tensor: + return self._emb_module( + features.values(), + features.offsets(), + per_sample_weights=features.weights_or_none(), + ) diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py index d167614d3..e4800081f 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_function.py @@ -13,20 +13,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import List, Optional import torch -from dynamicemb.dynamicemb_config import DynamicEmbPoolingMode, dyn_emb_to_torch +from dynamicemb.dynamicemb_config import ( + DynamicEmbInitializerArgs, + DynamicEmbPoolingMode, + dyn_emb_to_torch, +) +from dynamicemb.initializer import BaseDynamicEmbInitializer +from dynamicemb.key_value_table import ( + Cache, + KeyValueTableCachingFunction, + KeyValueTableFunction, + Storage, +) from dynamicemb.optimizer import BaseDynamicEmbeddingOptimizer +from dynamicemb.types import Counter from dynamicemb.unique_op import UniqueOp from dynamicemb_extensions import ( DynamicEmbTable, + EvictStrategy, + find_and_initialize, find_or_insert, + get_table_range, lookup_backward, lookup_backward_dense, lookup_backward_dense_dedup, lookup_forward, lookup_forward_dense, + lookup_forward_dense_eval, ) @@ -51,7 +67,9 @@ def forward( unique_op: UniqueOp, device: torch.device, optimizer: BaseDynamicEmbeddingOptimizer, - *args + training: bool, + eval_initializers: List[DynamicEmbInitializerArgs], + *args, ): # TODO: remove unnecessary params. # TODO:need check dimension is right @@ -114,13 +132,22 @@ def forward( num_unique_indices, dims[i], dtype=tmp_value_type_torch, device=device ) - find_or_insert( - tables[i], - num_unique_indices, - unique_indices, - tmp_unique_embs, - scores[i], - ) + if training: + find_or_insert( + tables[i], + num_unique_indices, + unique_indices, + tmp_unique_embs, + scores[i], + ) + else: + find_and_initialize( + tables[i], + num_unique_indices, + unique_indices, + tmp_unique_embs, + eval_initializers[i].as_ctype(), + ) unique_embedding_list.append(tmp_unique_embs) @@ -162,19 +189,20 @@ def forward( accum_D += dims[i] * (num_embeddings // batch_size) assert num_embeddings % batch_size == 0 - backward_tensors = [indices, offsets] - ctx.save_for_backward(*backward_tensors) - ctx.tables = tables - ctx.unique_indices_list = unique_indices_list - ctx.inverse_indices_list = inverse_indices_list - ctx.biased_offsets_list = biased_offsets_list - ctx.dims = dims - ctx.batch_size = batch_size - ctx.feature_num = feature_num - ctx.feature_table_map = feature_table_map - ctx.device = device - ctx.optimizer = optimizer - ctx.scores = scores + if training: + backward_tensors = [indices, offsets] + ctx.save_for_backward(*backward_tensors) + ctx.tables = tables + ctx.unique_indices_list = unique_indices_list + ctx.inverse_indices_list = inverse_indices_list + ctx.biased_offsets_list = biased_offsets_list + ctx.dims = dims + ctx.batch_size = batch_size + ctx.feature_num = feature_num + ctx.feature_table_map = feature_table_map + ctx.device = device + ctx.optimizer = optimizer + ctx.scores = scores return embs @@ -193,6 +221,7 @@ def backward(ctx, grad): device = ctx.device optimizer = ctx.optimizer table_num = len(tables) + combiner = ctx.combiner offsets_list_per_table = [] for i in range(table_num): @@ -245,17 +274,16 @@ def backward(ctx, grad): batch_size, feature_num_per_table[i], offsets_list_per_table[i][-1].item(), + combiner, ) unique_grads_per_table = [] for i, unique_grad in enumerate(unique_backward_grads_per_table): unique_grads_per_table.append(unique_grad.reshape(-1, dims[i])) - scores = ctx.scores + optimizer.update(tables, unique_indices_list, unique_grads_per_table) - optimizer.update(tables, unique_indices_list, unique_grads_per_table, scores) - - return (None,) * 17 + return (None,) * 19 class DynamicEmbeddingFunction(torch.autograd.Function): @@ -278,7 +306,9 @@ def forward( unique_op: UniqueOp, device: torch.device, optimizer: BaseDynamicEmbeddingOptimizer, - *args + training: bool, + eval_initializers: List[DynamicEmbInitializerArgs], + *args, ): # TODO:need check dimension is right table_num = len(tables) @@ -288,85 +318,107 @@ def forward( batch_size = feature_batch_size // feature_num assert feature_batch_size % feature_num == 0 - d_unique_offsets = torch.zeros(table_num + 1, dtype=torch.uint64, device=device) - h_unique_offsets = torch.empty(table_num + 1, dtype=torch.uint64, device="cpu") - table_offsets = torch.empty(table_num + 1, dtype=offsets.dtype, device=device) - # table' dtype - unique_embs = torch.empty( - indices.shape[0], dim, dtype=embedding_dtype, device=device - ) - # output' dtype - output_embs = torch.empty( - indices.shape[0], dim, dtype=output_dtype, device=device - ) - - # #TODO: if global dedup is done: - # if use_index_dedup: - # lookup_forward_dense( - # tables, - # indices, - # offsets, - # table_offsets_in_feature, - # table_num, - # batch_size, - # dim, - # h_unique_offsets, # used in backward, actually plays a role as table_offsets. - # unique_embs, # serve as a tmp buffer. - # output_embs) - # else: - # TODO:in our case , maybe uint32 is enough for reverse_idx - reverse_idx = torch.empty_like(indices, dtype=torch.uint64, device=device) - unique_idx = torch.empty_like(indices, dtype=indices.dtype, device=device) - h_unique_nums = torch.empty(table_num, dtype=torch.uint64, device="cpu") - d_unique_nums = torch.empty(table_num, dtype=torch.uint64, device=device) - lookup_forward_dense( - tables, - indices, - offsets, - scores, - table_offsets_in_feature, - table_offsets, - table_num, - batch_size, - dim, - use_index_dedup, - unique_idx, - reverse_idx, - h_unique_nums, - d_unique_nums, - h_unique_offsets, - d_unique_offsets, - unique_embs, - output_embs, - device_num_sms, - unique_op, - ) - if use_index_dedup: - unique_idx_forback = torch.empty( - h_unique_offsets[-1], dtype=indices.dtype, device=device + if training: + d_unique_offsets = torch.zeros( + table_num + 1, dtype=torch.uint64, device=device ) - unique_idx_forback.copy_( - unique_idx[: h_unique_offsets[-1]], non_blocking=True + h_unique_offsets = torch.empty( + table_num + 1, dtype=torch.uint64, device="cpu" + ) + table_offsets = torch.empty( + table_num + 1, dtype=offsets.dtype, device=device + ) + # table' dtype + unique_embs = torch.empty( + indices.shape[0], dim, dtype=embedding_dtype, device=device + ) + # output' dtype + output_embs = torch.empty( + indices.shape[0], dim, dtype=output_dtype, device=device ) - backward_tensors = [indices, offsets] - ctx.save_for_backward(*backward_tensors) - ctx.tables = tables - ctx.dim = dim - ctx.device = device - ctx.optimizer = optimizer - - # optimize need - ctx.h_unique_offsets = h_unique_offsets - ctx.table_offsets = table_offsets - ctx.use_index_dedup = use_index_dedup - ctx.device_num_sms = device_num_sms - if use_index_dedup: - ctx.reverse_idx = reverse_idx - ctx.unique_idx_forback = unique_idx_forback - ctx.scores = scores - - return output_embs + # #TODO: if global dedup is done: + # if use_index_dedup: + # lookup_forward_dense( + # tables, + # indices, + # offsets, + # table_offsets_in_feature, + # table_num, + # batch_size, + # dim, + # h_unique_offsets, # used in backward, actually plays a role as table_offsets. + # unique_embs, # serve as a tmp buffer. + # output_embs) + # else: + # TODO:in our case , maybe uint32 is enough for reverse_idx + reverse_idx = torch.empty_like(indices, dtype=torch.uint64, device=device) + unique_idx = torch.empty_like(indices, dtype=indices.dtype, device=device) + h_unique_nums = torch.empty(table_num, dtype=torch.uint64, device="cpu") + d_unique_nums = torch.empty(table_num, dtype=torch.uint64, device=device) + lookup_forward_dense( + tables, + indices, + offsets, + scores, + table_offsets_in_feature, + table_offsets, + table_num, + batch_size, + dim, + use_index_dedup, + unique_idx, + reverse_idx, + h_unique_nums, + d_unique_nums, + h_unique_offsets, + d_unique_offsets, + unique_embs, + output_embs, + device_num_sms, + unique_op, + ) + if use_index_dedup: + unique_idx_forback = torch.empty( + h_unique_offsets[-1], dtype=indices.dtype, device=device + ) + unique_idx_forback.copy_( + unique_idx[: h_unique_offsets[-1]], non_blocking=True + ) + unique_emb_forback = unique_embs[: h_unique_offsets[-1], :] + + backward_tensors = [indices, offsets] + ctx.save_for_backward(*backward_tensors) + ctx.tables = tables + ctx.dim = dim + ctx.device = device + ctx.optimizer = optimizer + + # optimize need + ctx.h_unique_offsets = h_unique_offsets + ctx.table_offsets = table_offsets + ctx.use_index_dedup = use_index_dedup + ctx.device_num_sms = device_num_sms + if use_index_dedup: + ctx.reverse_idx = reverse_idx + ctx.unique_idx_forback = unique_idx_forback + ctx.unique_emb_forback = unique_emb_forback + ctx.scores = scores + + return output_embs + else: + return lookup_forward_dense_eval( + tables, + indices, + offsets, + table_offsets_in_feature, + embedding_dtype, + table_num, + batch_size, + dim, + device, + [initializer.as_ctype() for initializer in eval_initializers], + ).to(output_dtype) @staticmethod def backward(ctx, grads): @@ -379,17 +431,15 @@ def backward(ctx, grads): device = ctx.device tables = ctx.tables optimizer = ctx.optimizer - use_index_dedup = ctx.use_index_dedup - device_num_sms = ctx.device_num_sms - if use_index_dedup: - reverse_idx = ctx.reverse_idx - unique_idx_forback = ctx.unique_idx_forback table_num = len(tables) unique_indices_list = [] unique_grads_list = [] - - if use_index_dedup: + if ctx.use_index_dedup: + device_num_sms = ctx.device_num_sms + reverse_idx = ctx.reverse_idx + unique_idx_forback = ctx.unique_idx_forback + ctx.unique_emb_forback unique_grads = torch.zeros( h_unique_offsets[-1], dim, dtype=grads.dtype, device=device ) @@ -433,7 +483,228 @@ def backward(ctx, grads): unique_grads[h_unique_offsets[i] : h_unique_offsets[i + 1], :] ) - scores = ctx.scores # optimizer: update tables. - optimizer.update(tables, unique_indices_list, unique_grads_list, scores) + optimizer.update(tables, unique_indices_list, unique_grads_list) + return (None,) * 19 + + +def dynamicemb_prefetch( + indices: torch.Tensor, + offsets: torch.Tensor, + caches: List[Optional[Cache]], + storages: List[Storage], + feature_offsets: torch.Tensor, + initializers: List[BaseDynamicEmbInitializer], + unique_op, + training: bool = True, + forward_stream: Optional[torch.cuda.Stream] = None, +): + table_num = len(storages) + assert table_num != 0 + caching = caches[0] is not None + + indices_table_range = get_table_range(offsets, feature_offsets) + if training or caching: + ( + unique_indices, + inverse, + unique_indices_table_range, + h_unique_indices_table_range, + _, + ) = segmented_unique(indices, indices_table_range, unique_op) + # TODO: only return device unique_indices_table_range + # h_unique_indices_table_range = unique_indices_table_range.cpu() + else: + h_unique_indices_table_range = indices_table_range.cpu() + unique_indices = indices + + for i in range(table_num): + begin = h_unique_indices_table_range[i] + end = h_unique_indices_table_range[i + 1] + unique_indices_per_table = unique_indices[begin:end] + + KeyValueTableFunction.prefetch( + caches[i], + storages[i], + unique_indices_per_table, + initializers[i], + training, + forward_stream, + ) + + +class DynamicEmbeddingFunctionV2(torch.autograd.Function): + @staticmethod + def forward( + ctx, + indices: torch.Tensor, + offsets: torch.Tensor, + caches: List[Optional[Cache]], + storages: List[Storage], + feature_offsets: torch.Tensor, + output_dtype: torch.dtype, + initializers: List[BaseDynamicEmbInitializer], + optimizer: BaseDynamicEmbeddingOptimizer, + unique_op, + enable_prefetch: bool = False, + input_dist_dedup: bool = False, + training: bool = True, + admit_strategy=None, + evict_strategy=None, + frequency_counters: Optional[torch.Tensor] = None, + admission_counter: Optional[List[Counter]] = None, + *args, + ): + table_num = len(storages) + assert table_num != 0 + emb_dtype = storages[0].embedding_dtype() + emb_dim = storages[0].embedding_dim() + caching = caches[0] is not None + + frequency_counts_int64 = None + if frequency_counters is not None: + frequency_counts_int64 = frequency_counters.long() + + lfu_accumulated_frequency = None + indices_table_range = get_table_range(offsets, feature_offsets) + if training or caching: + ( + unique_indices, + inverse, + unique_indices_table_range, + h_unique_indices_table_range, + lfu_accumulated_frequency, + ) = segmented_unique( + indices, + indices_table_range, + unique_op, + EvictStrategy(evict_strategy.value) if evict_strategy else None, + frequency_counts_int64, + ) + # TODO: only return device unique_indices_table_range + # h_unique_indices_table_range = unique_indices_table_range.cpu() + else: + h_unique_indices_table_range = indices_table_range.cpu() + unique_indices = indices + + unique_embs = torch.empty( + unique_indices.shape[0], emb_dim, dtype=emb_dtype, device=indices.device + ) + + for i in range(table_num): + begin = h_unique_indices_table_range[i] + end = h_unique_indices_table_range[i + 1] + unique_indices_per_table = unique_indices[begin:end] + unique_embs_per_table = unique_embs[begin:end, :] + # Slice lfu_accumulated_frequency to match the table + lfu_accumulated_frequency_per_table = ( + lfu_accumulated_frequency[begin:end] + if lfu_accumulated_frequency is not None + and lfu_accumulated_frequency.numel() > 0 + else None + ) + + if caching: + KeyValueTableCachingFunction.lookup( + caches[i], + storages[i], + unique_indices_per_table, + unique_embs_per_table, + initializers[i], + enable_prefetch, + training, + EvictStrategy(evict_strategy.value) if evict_strategy else None, + lfu_accumulated_frequency_per_table, + admit_strategy, + admission_counter[i] if admission_counter else None, + ) + else: + KeyValueTableFunction.lookup( + storages[i], + unique_indices_per_table, + unique_embs_per_table, + initializers[i], + training, + EvictStrategy(evict_strategy.value) if evict_strategy else None, + lfu_accumulated_frequency_per_table, + admit_strategy, + admission_counter[i] if admission_counter else None, + ) + + if training or caching: + output_embs = torch.empty( + indices.shape[0], emb_dim, dtype=output_dtype, device=indices.device + ) + output_embs = unique_embs[inverse] + else: + output_embs = unique_embs + + if training: + # save context + backward_tensors = [ + indices, + ] + ctx.save_for_backward(*backward_tensors) + ctx.input_dist_dedup = input_dist_dedup + if input_dist_dedup: + ctx.unique_indices = unique_indices + ctx.unique_embs = unique_embs + ctx.inverse = inverse + ctx.indices_table_range = indices_table_range + ctx.h_indices_table_range = indices_table_range.cpu() + ctx.h_unique_indices_table_range = h_unique_indices_table_range + ctx.unique_indices_table_range = unique_indices_table_range + ctx.caches = caches + ctx.storages = storages + ctx.optimizer = optimizer + ctx.enable_prefetch = enable_prefetch + + return output_embs + + @staticmethod + def backward(ctx, grads): + # parse context + (indices,) = ctx.saved_tensors + indices_table_range = ctx.indices_table_range + h_indices_table_range = ctx.h_indices_table_range + h_unique_indices_table_range = ctx.h_unique_indices_table_range + ctx.unique_indices_table_range + caches = ctx.caches + storages = ctx.storages + optimizer = ctx.optimizer + caching = caches[0] is not None + + input_dist_dedup = ctx.input_dist_dedup + if input_dist_dedup: + unique_indices = ctx.unique_indices + unique_embs = ctx.unique_embs + ctx.inverse + unique_indices, unique_embs = reduce_grads( + indices, grads, indices_table_range, h_indices_table_range + ) + optimizer.step() + table_num = len(storages) + for i in range(table_num): + begin = h_unique_indices_table_range[i] + end = h_unique_indices_table_range[i + 1] + unique_indices_per_table = unique_indices[begin:end] + unique_embs_per_table = unique_embs[begin:end, :] + + if caching: + KeyValueTableCachingFunction.update( + caches[i], + storages[i], + unique_indices_per_table, + unique_embs_per_table, + optimizer, + ctx.enable_prefetch, + ) + else: + KeyValueTableFunction.update( + storages[i], + unique_indices_per_table, + unique_embs_per_table, + optimizer, + ) + return (None,) * 17 diff --git a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py index c51864bdd..ba38bb34e 100644 --- a/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py +++ b/corelib/dynamicemb/dynamicemb/batched_dynamicemb_tables.py @@ -14,26 +14,28 @@ # limitations under the License. import enum +import logging import warnings +from copy import deepcopy from dataclasses import dataclass, field +from functools import partial from itertools import accumulate -from typing import List, Optional, Tuple +from typing import List, Optional, Tuple, cast import torch # usort:skip import torch.distributed as dist -from dynamicemb.batched_dynamicemb_function import * +from dynamicemb.batched_dynamicemb_function import ( + DynamicEmbeddingBagFunction, + DynamicEmbeddingFunctionV2, + dynamicemb_prefetch, +) from dynamicemb.dynamicemb_config import * +from dynamicemb.initializer import * +from dynamicemb.key_value_table import Cache, KeyValueTable, Storage from dynamicemb.optimizer import * from dynamicemb.unique_op import UniqueOp -from dynamicemb_extensions import ( - DynamicEmbTable, - OptimizerType, - count_matched, - device_timestamp, - dyn_emb_capacity, - dyn_emb_cols, - export_batch_matched, -) +from dynamicemb.utils import tabulate +from dynamicemb_extensions import DynamicEmbTable, OptimizerType, device_timestamp from torch import Tensor, nn # usort:skip @@ -111,8 +113,147 @@ class CowClipDefinition: lower_bound: float = 0.0 +def encode_meta_json_file_path(root_path: str, table_name: str) -> str: + return os.path.join(root_path, f"{table_name}_opt_args.json") + + +def encode_checkpoint_file_path( + root_path: str, table_name: str, rank: int, world_size: int, item: str +) -> str: + assert item in ["keys", "values", "scores", "opt_values"] + return os.path.join( + root_path, f"{table_name}_emb_{item}.rank_{rank}.world_size_{world_size}" + ) + + +def encode_counter_checkpoint_file_path( + root_path: str, table_name: str, rank: int, world_size: int, item: str +) -> str: + assert item in ["keys", "frequencies"] + return os.path.join( + root_path, f"{table_name}_counter_{item}.rank_{rank}.world_size_{world_size}" + ) + + +def find_files(root_path: str, table_name: str, suffix: str) -> Tuple[List[str], int]: + suffix_to_encode_file_path_func = { + "emb_keys": partial(encode_checkpoint_file_path, item="keys"), + "emb_values": partial(encode_checkpoint_file_path, item="values"), + "emb_scores": partial(encode_checkpoint_file_path, item="scores"), + "opt_values": partial(encode_checkpoint_file_path, item="opt_values"), + "counter_keys": partial(encode_counter_checkpoint_file_path, item="keys"), + "counter_frequencies": partial( + encode_counter_checkpoint_file_path, item="frequencies" + ), + } + if suffix not in suffix_to_encode_file_path_func: + raise RuntimeError(f"Invalid suffix: {suffix}") + encode_file_path_func = suffix_to_encode_file_path_func[suffix] + + import glob + + # v2 version + files = glob.glob(encode_file_path_func(root_path, table_name, "*", "*")) + if len(files) == 0: + return [], 0 + files = sorted(files) + world_size = int(files[0].split(".")[-1].split("_")[-1]) + if len(files) != world_size: + raise RuntimeError( + f"Checkpoints is corrupted. Found {len(files)} under path {root_path} for table {table_name}, but the number of checkpointed world size is {world_size}." + ) + + for i in range(world_size): + expected_file_path = encode_file_path_func(root_path, table_name, i, world_size) + if expected_file_path not in set(files): + raise RuntimeError( + f"Checkpoints is corrupted. Expected file path {expected_file_path} for table {table_name}, but it is not found." + ) + + return files, len(files) + + +def get_loading_files( + root_path: str, + name: str, + rank: int, + world_size: int, +) -> Tuple[List[str], List[str], List[str], List[str], int, int]: + if not os.path.exists(root_path): + raise RuntimeError(f"can't find path to load, path:", root_path) + + key_files, num_key_files = find_files(root_path, name, "emb_keys") + value_files, num_value_files = find_files(root_path, name, "emb_values") + score_files, num_score_files = find_files(root_path, name, "emb_scores") + opt_files, num_opt_files = find_files(root_path, name, "opt_values") + + if num_key_files != num_value_files: + assert ( + num_key_files > 0 + ), "No key files found under path {root_path} for table {name}" + raise RuntimeError( + f"The number of key files under path {root_path} for table {name} does not match the number of value files." + ) + + counter_key_files, num_counter_key_files = find_files( + root_path, name, "counter_keys" + ) + counter_freq_files, num_counter_freq_files = find_files( + root_path, name, "counter_frequencies" + ) + + if num_counter_key_files != num_counter_freq_files: + raise RuntimeError( + f"The number of key files of admission counter under path {root_path} for table {name} does not match the number of frequency files({num_counter_key_files}/{num_counter_freq_files})." + ) + + if num_counter_key_files > 0 and num_counter_key_files != num_key_files: + raise RuntimeError( + f"The number of key files under path {root_path} for table {name} does not match the number of keys files of admission counter({num_key_files}/{num_counter_key_files})." + ) + + if world_size == num_key_files: + return ( + [encode_checkpoint_file_path(root_path, name, rank, world_size, "keys")], + [encode_checkpoint_file_path(root_path, name, rank, world_size, "values")], + [encode_checkpoint_file_path(root_path, name, rank, world_size, "scores")] + if num_score_files == num_key_files + else [], + [ + encode_checkpoint_file_path( + root_path, name, rank, world_size, "opt_values" + ) + ] + if num_opt_files == num_key_files + else [], + [ + encode_counter_checkpoint_file_path( + root_path, name, rank, world_size, "keys" + ) + ] + if num_counter_key_files == num_key_files + else [], + [ + encode_counter_checkpoint_file_path( + root_path, name, rank, world_size, "frequencies" + ) + ] + if num_counter_freq_files == num_key_files + else [], + ) + # TODO: support skipping files. + return ( + key_files, + value_files, + score_files, + opt_files, + counter_key_files, + counter_freq_files, + ) + + def _export_matched_and_gather( - dynamic_table: DynamicEmbTable, + dynamic_table: KeyValueTable, threshold: int, pg: Optional[dist.ProcessGroup] = None, batch_size: int = BATCH_SIZE_PER_DUMP, @@ -123,7 +264,7 @@ def _export_matched_and_gather( device = torch.device(f"cuda:{torch.cuda.current_device()}") d_num_matched = torch.zeros(1, dtype=torch.uint64, device=device) - count_matched(dynamic_table, threshold, d_num_matched) + dynamic_table.count_matched(threshold, d_num_matched) gathered_num_matched = [ torch.tensor(0, dtype=torch.int64, device=device) for _ in range(world_size) @@ -131,36 +272,38 @@ def _export_matched_and_gather( dist.all_gather(gathered_num_matched, d_num_matched.to(dtype=torch.int64), group=pg) total_matched = sum([t.item() for t in gathered_num_matched]) # t is on device. - key_dtype = dyn_emb_to_torch(dynamic_table.key_type()) - value_dtype = dyn_emb_to_torch(dynamic_table.value_type()) - dim: int = dyn_emb_cols(dynamic_table) + key_dtype = dynamic_table.key_type() + value_dtype = dynamic_table.value_type() + dim: int = dynamic_table.embedding_dim() + total_dim = dynamic_table.value_dim() ret_keys = torch.empty(total_matched, dtype=key_dtype, device="cpu") ret_vals = torch.empty(total_matched * dim, dtype=value_dtype, device="cpu") ret_offset = 0 search_offset = 0 - search_capacity = dyn_emb_capacity(dynamic_table) - batch_size = batch_size if batch_size < search_capacity else search_capacity + search_capacity = dynamic_table.capacity() d_keys = torch.empty(batch_size, dtype=key_dtype, device=device) - d_vals = torch.empty(batch_size * dim, dtype=value_dtype, device=device) + d_embs = torch.empty(batch_size * dim, dtype=value_dtype, device=device) + d_vals = torch.empty(batch_size * total_dim, dtype=value_dtype, device=device) d_count = torch.zeros(1, dtype=torch.uint64, device=device) # Gather keys and values for all ranks gathered_keys = [torch.empty_like(d_keys) for _ in range(world_size)] - gathered_vals = [torch.empty_like(d_vals) for _ in range(world_size)] + gathered_vals = [torch.empty_like(d_embs) for _ in range(world_size)] gathered_counts = [ torch.empty_like(d_count, dtype=torch.int64) for _ in range(world_size) ] while search_offset < search_capacity: - export_batch_matched( - dynamic_table, threshold, batch_size, search_offset, d_count, d_keys, d_vals + dynamic_table.export_batch_matched( + threshold, batch_size, search_offset, d_count, d_keys, d_vals ) + d_embs = d_vals.view(batch_size, total_dim)[:, :dim].reshape(-1) dist.all_gather(gathered_keys, d_keys, group=pg) - dist.all_gather(gathered_vals, d_vals, group=pg) + dist.all_gather(gathered_vals, d_embs, group=pg) dist.all_gather(gathered_counts, d_count.to(dtype=torch.int64), group=pg) for d_keys_, d_vals_, d_count_ in zip( @@ -180,41 +323,42 @@ def _export_matched_and_gather( def _export_matched( - dynamic_table: DynamicEmbTable, + dynamic_table: KeyValueTable, threshold: int, batch_size: int = BATCH_SIZE_PER_DUMP, ) -> Tuple[Tensor, Tensor]: device = torch.device(f"cuda:{torch.cuda.current_device()}") d_num_matched = torch.zeros(1, dtype=torch.uint64, device=device) - count_matched(dynamic_table, threshold, d_num_matched) + dynamic_table.count_matched(threshold, d_num_matched) total_matched = d_num_matched.cpu().item() - key_dtype = dyn_emb_to_torch(dynamic_table.key_type()) - value_dtype = dyn_emb_to_torch(dynamic_table.value_type()) - dim: int = dyn_emb_cols(dynamic_table) + key_dtype = dynamic_table.key_type() + value_dtype = dynamic_table.value_type() + dim: int = dynamic_table.embedding_dim() + total_dim = dynamic_table.value_dim() ret_keys = torch.empty(total_matched, dtype=key_dtype, device="cpu") ret_vals = torch.empty(total_matched * dim, dtype=value_dtype, device="cpu") ret_offset = 0 search_offset = 0 - search_capacity = dyn_emb_capacity(dynamic_table) + search_capacity = dynamic_table.capacity() batch_size = batch_size if batch_size < search_capacity else search_capacity d_keys = torch.empty(batch_size, dtype=key_dtype, device=device) - d_vals = torch.empty(batch_size * dim, dtype=value_dtype, device=device) + d_vals = torch.empty(batch_size * total_dim, dtype=value_dtype, device=device) d_count = torch.zeros(1, dtype=torch.uint64, device=device) while search_offset < search_capacity: - export_batch_matched( - dynamic_table, threshold, batch_size, search_offset, d_count, d_keys, d_vals + dynamic_table.export_batch_matched( + threshold, batch_size, search_offset, d_count, d_keys, d_vals ) h_count = d_count.cpu().item() ret_keys[ret_offset : ret_offset + h_count] = d_keys[0:h_count].cpu() - ret_vals[ret_offset * dim : (ret_offset + h_count) * dim] = d_vals[ - 0 : h_count * dim - ].cpu() + ret_vals[ret_offset * dim : (ret_offset + h_count) * dim] = ( + d_vals.view(batch_size, total_dim)[:h_count, :dim].reshape(-1).cpu() + ) ret_offset += h_count search_offset += batch_size @@ -223,7 +367,98 @@ def _export_matched( return ret_keys, ret_vals -class BatchedDynamicEmbeddingTables(nn.Module): +def _print_memory_consume( + table_names, dynamicemb_options, optimizer, device_id +) -> None: + subtitle = [ + "", + "total", + "embedding", + "optim_state", + "total", + "embedding", + "optim_state", + "total", + "embedding", + "optim_state", + ] + table_consume = [] + table_consume.append(subtitle) + + def _get_optimizer_state_dim(optimizer_type, dim, element_size): + if optimizer_type == OptimizerType.RowWiseAdaGrad: + return 16 // element_size + elif optimizer_type == OptimizerType.Adam: + return dim * 2 + elif optimizer_type == OptimizerType.AdaGrad: + return dim + else: + return 0 + + DTYPE_NUM_BYTES: Dict[torch.dtype, int] = { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + } + + def MB_(x) -> int: + return x // (1024 * 1024) + + def KB_(x) -> int: + return x // (1024) + + F = None + + for table_name, table_option in zip(table_names, dynamicemb_options): + element_size = DTYPE_NUM_BYTES[table_option.embedding_dtype] + emb_dim = table_option.dim + if optimizer is not None: + optim_state_dim = optimizer.get_state_dim(emb_dim) + else: + optim_state_dim = _get_optimizer_state_dim( + table_option.optimizer_type, emb_dim, element_size + ) + total_dim = emb_dim + optim_state_dim + total_memory = table_option.max_capacity * element_size * total_dim + if F is None: + if total_memory // (1024 * 1024) != 0: + F = MB_ + else: + F = KB_ + local_hbm_for_values = min(table_option.local_hbm_for_values, total_memory) + local_dram_for_values = total_memory - local_hbm_for_values + table_consume.append( + [ + table_name, + F(total_memory), + F(table_option.max_capacity * element_size * emb_dim), + F(table_option.max_capacity * element_size * optim_state_dim), + F(local_hbm_for_values), + F(int(local_hbm_for_values * emb_dim // total_dim)), + F(int(local_hbm_for_values * optim_state_dim // total_dim)), + F(local_dram_for_values), + F(int(local_dram_for_values * emb_dim // total_dim)), + F(int(local_dram_for_values * optim_state_dim // total_dim)), + ] + ) + unit = "MB" if F == MB_ else "KB" + title = [ + "table name", + "", + f"memory({unit})", + "", + "", + f"hbm({unit})/cuda:{device_id}", + "", + "", + f"dram({unit})", + "", + ] + output = "\n\n" + tabulate(table_consume, title, sub_headers=True) + print(output) + + +class BatchedDynamicEmbeddingTablesV2(nn.Module): """ Dynamic Embedding is based on [HKV](https://github.com/NVIDIA-Merlin/HierarchicalKV/tree/master). Looks up one or more dynamic embedding tables. The module is application for training. @@ -239,6 +474,7 @@ def __init__( table_names: Optional[List[str]] = None, feature_table_map: Optional[List[int]] = None, # [T] use_index_dedup: bool = False, + prefetch_pipeline: bool = False, # we set the arg name same as FBGEMM TBE to align with it pooling_mode: DynamicEmbPoolingMode = DynamicEmbPoolingMode.SUM, output_dtype: torch.dtype = torch.float32, device: torch.device = None, @@ -272,6 +508,9 @@ def __init__( cowclip_regularization: Optional[ CowClipDefinition ] = None, # used by Rowwise Adagrad + # TO align with FBGEMM TBE + *args, + **kwargs, ) -> None: super().__init__() assert len(table_options) >= 1 @@ -287,10 +526,14 @@ def __init__( self.output_dtype = output_dtype self.pooling_mode = pooling_mode self.use_index_dedup = use_index_dedup + self._enable_prefetch = prefetch_pipeline + self.prefetch_stream = None + self.num_prefetch_ahead = 0 self._table_names = table_names self.bounds_check_mode_int: int = bounds_check_mode.value self._create_score() - + self._admit_strategy = self._dynamicemb_options[0].admit_strategy + self._evict_strategy = self._dynamicemb_options[0].evict_strategy.value if device is not None: self.device_id = int(str(device)[-1]) else: @@ -345,19 +588,202 @@ def __init__( self.table_offsets_in_feature.append(idx) old_table_id = table_id self.table_offsets_in_feature.append(self.feature_num) + self.feature_offsets = torch.tensor( + self.table_offsets_in_feature, + device=torch.device(self.device_id), + dtype=torch.int64, + ) for option in self._dynamicemb_options: if option.init_capacity is None: option.init_capacity = option.max_capacity - self._optimizer_type = optimizer - self._tables: List[DynamicEmbTable] = [] - self._create_tables() - # add placeholder require_grad param tensor to enable autograd with int8 weights - # self.placeholder_autograd_tensor = nn.Parameter( - # torch.zeros(0, device=torch.device(self.device_id), dtype=torch.float) - # ) - # TODO: review this code block + self._optimizer: Union[ + BaseDynamicEmbeddingOptimizer, BaseDynamicEmbeddingOptimizerV2 + ] = None + self._create_optimizer( + optimizer, + stochastic_rounding, + gradient_clipping, + max_gradient, + max_norm, + learning_rate, + eps, + initial_accumulator_value, + beta1, + beta2, + weight_decay, + eta, + momentum, + weight_decay_mode, + counter_based_regularization, + cowclip_regularization, + ) + self._storage_externel = table_option.external_storage is not None + self._create_cache_storage() + if self.pooling_mode != DynamicEmbPoolingMode.NONE: + self._tables = [] + for storage in self._storages: + assert isinstance( + storage, KeyValueTable + ), "The storage should be KeyValueTable when pooling mode is not None." + kvtable = cast(KeyValueTable, storage) + self._tables.append(kvtable.table) + self._create_bag_optimizer( + self._optimizer_type, self._optimizer_args, self._tables + ) + self._initializers = [] + self._eval_initializers = [] + self._create_initializers() + + self._admission_counter = [option.admission_counter for option in table_options] + + # TODO:1->10 + self._empty_tensor = nn.Parameter( + torch.empty( + 10, + requires_grad=True, + device=torch.device(self.device_id), + dtype=self.embedding_dtype, + ) + ) + + # new a unique op + # TODO: in our case maybe we can use torch.uint32 + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + count_dtype = torch.long + else: + count_dtype = torch.uint64 + reserve_keys = torch.tensor( + 2, dtype=self.index_type, device=torch.device(self.device_id) + ) + reserve_vals = torch.tensor( + 2, dtype=count_dtype, device=torch.device(self.device_id) + ) + counter = torch.tensor( + 1, dtype=count_dtype, device=torch.device(self.device_id) + ) + self._unique_op = UniqueOp(reserve_keys, reserve_vals, counter, 2) + + def _create_cache_storage(self) -> None: + self._storages: List[Storage] = [] + self._caches: List[Cache] = [] + self._caching = self._dynamicemb_options[0].caching + + for option in self._dynamicemb_options: + if option.training and option.optimizer_type == OptimizerType.Null: + option.optimizer_type = convert_optimizer_type(self._optimizer_type) + elif not option.training and option.optimizer_type != OptimizerType.Null: + option.optimizer_type = OptimizerType.Null + warnings.warn( + "Set OptimizerType to Null as not on training mode.", UserWarning + ) + + if option.caching and option.training: + cache_option = deepcopy(option) + cache_option.bucket_capacity = 1024 + capacity = get_constraint_capacity( + option.local_hbm_for_values, + option.embedding_dtype, + option.dim, + option.optimizer_type, + cache_option.bucket_capacity, + ) + if capacity == 0: + raise ValueError( + "Can't use caching mode as the reserved HBM size is too small." + ) + + cache_option.max_capacity = capacity + cache_option.init_capacity = capacity + self._caches.append(KeyValueTable(cache_option, self._optimizer)) + + storage_option = deepcopy(option) + storage_option.local_hbm_for_values = 0 + PS = storage_option.external_storage + self._storages.append( + PS(storage_option, self._optimizer) + if PS + else KeyValueTable(storage_option, self._optimizer) + ) + else: + self._caches.append(None) + self._storages.append(KeyValueTable(option, self._optimizer)) + + _print_memory_consume( + self._table_names, self._dynamicemb_options, self._optimizer, self.device_id + ) + + def _create_initializers(self) -> None: + for option in self._dynamicemb_options: + initializer = create_initializer_from_args(option.initializer_args) + self._initializers.append(initializer) + eval_initializer = create_initializer_from_args( + option.eval_initializer_args + ) + self._eval_initializers.append(eval_initializer) + + def _create_bag_optimizer( + self, + optimizer_type: EmbOptimType, + optimizer_args: OptimizerArgs, + tables: List[DynamicEmbTable], + ) -> None: + if optimizer_type == EmbOptimType.SGD: + self._bag_optimizer = SGDDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + tables, + ) + elif optimizer_type == EmbOptimType.EXACT_SGD: + self._bag_optimizer = SGDDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + tables, + ) + elif optimizer_type == EmbOptimType.ADAM: + self._bag_optimizer = AdamDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + tables, + ) + elif optimizer_type == EmbOptimType.EXACT_ADAGRAD: + self._bag_optimizer = AdaGradDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + tables, + ) + elif optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: + self._bag_optimizer = RowWiseAdaGradDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + tables, + ) + else: + raise ValueError( + f"Not supported optimizer type ,optimizer type = {optimizer_type} {type(optimizer_type)} {optimizer_type.value}." + ) + + def _create_optimizer( + self, + optimizer_type: EmbOptimType, + stochastic_rounding: bool, + gradient_clipping: bool, + max_gradient: float, + max_norm: float, + learning_rate: float, + eps: float, + initial_accumulator_value: float, + beta1: float, + beta2: float, + weight_decay: float, + eta: float, + momentum: float, + weight_decay_mode: WeightDecayMode, + counter_based_regularization: Optional[CounterBasedRegularizationDefinition], + cowclip_regularization: Optional[CowClipDefinition], + ) -> None: + self._optimizer_type = optimizer_type self.stochastic_rounding = stochastic_rounding self.weight_decay_mode = weight_decay_mode @@ -375,7 +801,7 @@ def __init__( ) self._used_rowwise_adagrad_with_counter: bool = ( - optimizer == EmbOptimType.EXACT_ROWWISE_ADAGRAD + optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD and ( weight_decay_mode in (WeightDecayMode.COUNTER, WeightDecayMode.COWCLIP) ) @@ -437,82 +863,28 @@ def __init__( lower_bound=cowclip_regularization.lower_bound, regularization_mode=weight_decay_mode.value, ) - self._optimizer: BaseDynamicEmbeddingOptimizer = None - self._create_optimizer(optimizer, optimizer_args) - - # TODO:1->10 - self._empty_tensor = nn.Parameter( - torch.empty( - 10, - requires_grad=True, - device=torch.device(self.device_id), - dtype=self.embedding_dtype, - ) - ) - - # new a unique op - # TODO: in our case maybe we can use torch.uint32 - reserve_keys = torch.tensor( - 2, dtype=self.index_type, device=torch.device(self.device_id) - ) - reserve_vals = torch.tensor( - 2, dtype=torch.uint64, device=torch.device(self.device_id) - ) - counter = torch.tensor( - 1, dtype=torch.uint64, device=torch.device(self.device_id) - ) - self._unique_op = UniqueOp(reserve_keys, reserve_vals, counter, 2) - - def _create_tables(self) -> None: - for option in self._dynamicemb_options: - if option.training: - if self._optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: - option.optimizer_type = OptimizerType.RowWiseAdaGrad - elif ( - self._optimizer_type == EmbOptimType.SGD - or self._optimizer_type == EmbOptimType.EXACT_SGD - ): - option.optimizer_type = OptimizerType.SGD - elif self._optimizer_type == EmbOptimType.ADAM: - option.optimizer_type = OptimizerType.Adam - elif self._optimizer_type == EmbOptimType.EXACT_ADAGRAD: - option.optimizer_type = OptimizerType.AdaGrad - else: - raise ValueError( - f"Not supported optimizer type ,optimizer type = {self._optimizer_type} {type(self._optimizer_type)} {self._optimizer_type.value}." - ) - self._tables.append(create_dynamicemb_table(option)) + self._optimizer_args = optimizer_args - def _create_optimizer( - self, - optimizer_type: EmbOptimType, - optimizer_args: OptimizerArgs, - ) -> None: if optimizer_type == EmbOptimType.SGD: - self._optimizer = SGDDynamicEmbeddingOptimizer( + self._optimizer = SGDDynamicEmbeddingOptimizerV2( optimizer_args, - self._dynamicemb_options, - self._tables, ) elif optimizer_type == EmbOptimType.EXACT_SGD: - self._optimizer = SGDDynamicEmbeddingOptimizer( + self._optimizer = SGDDynamicEmbeddingOptimizerV2( optimizer_args, - self._dynamicemb_options, - self._tables, ) elif optimizer_type == EmbOptimType.ADAM: - self._optimizer = AdamDynamicEmbeddingOptimizer( + self._optimizer = AdamDynamicEmbeddingOptimizerV2( optimizer_args, - self._dynamicemb_options, - self._tables, ) elif optimizer_type == EmbOptimType.EXACT_ADAGRAD: - self._optimizer = AdaGradDynamicEmbeddingOptimizer( - optimizer_args, self._dynamicemb_options, self._tables + self._optimizer = AdaGradDynamicEmbeddingOptimizerV2( + optimizer_args, ) elif optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: - self._optimizer = RowWiseAdaGradDynamicEmbeddingOptimizer( - optimizer_args, self._dynamicemb_options, self._tables + self._optimizer = RowWiseAdaGradDynamicEmbeddingOptimizerV2( + optimizer_args, + self.embedding_dtype, ) else: raise ValueError( @@ -533,27 +905,64 @@ def split_embedding_weights(self) -> List[Tensor]: return splits def flush(self) -> None: - return + self.num_prefetch_ahead = 0 + if self.pooling_mode == DynamicEmbPoolingMode.NONE and self._caching: + for cache, storage in zip(self._caches, self._storages): + cache.flush(storage) def reset_cache_states(self) -> None: - return + if self.pooling_mode == DynamicEmbPoolingMode.NONE and self._caching: + for cache in self._caches: + cache.reset() @property def table_names(self) -> List[str]: return self._table_names @property - def optimizer(self) -> BaseDynamicEmbeddingOptimizer: - return self._optimizer + def optimizer( + self, + ) -> Union[BaseDynamicEmbeddingOptimizer, BaseDynamicEmbeddingOptimizerV2]: + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + return self._optimizer + else: + return self._bag_optimizer @property - def tables(self) -> List[DynamicEmbTable]: - return self._tables + def tables(self) -> List[KeyValueTable]: + # if use external PS, the users should not get the KeyValueTables + # if self._storage_externel: + # raise RuntimeError( + # "Should not get the internal tables when using external storage." + # ) + return self._storages + + @property + def caches(self) -> List[Cache]: + return self._caches + + def set_record_cache_metrics(self, record: bool) -> None: + for cache in self._caches: + cache.set_record_cache_metrics(record) def set_learning_rate(self, lr: float) -> None: - self._optimizer.set_learning_rate(lr) + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + self._optimizer.set_learning_rate(lr) + else: + self._bag_optimizer.set_learning_rate(lr) return + @property + def enable_prefetch( + self, + ) -> None: + return self._enable_prefetch + + @enable_prefetch.setter + def enable_prefetch(self, value: bool): + self._enable_prefetch = value + self.num_prefetch_ahead = 0 + def forward( self, indices: Tensor, @@ -565,15 +974,19 @@ def forward( batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, total_unique_indices: Optional[int] = None, ) -> List[Tensor]: + if self._enable_prefetch: + self.num_prefetch_ahead -= 1 + if indices.dtype != self.index_type: indices = indices.to(self.index_type) - # offsets is on device, if we want to split the indices, we have to read the offset firstly. - # Jost forward it to DynamicEmbeddingFunction - # return DynamicEmbeddingFunction.apply(indices, offsets, self.table_offsets_in_feature, self.tables, self.total_D, - # self.dims,self.feature_table_map, self.embedding_dtype, self.pooling_mode, torch.device(self.device_id), 1, self._empty_tensor) + if any([not o.training for o in self._dynamicemb_options]) and self.training: + raise RuntimeError( + "BatchedDynamicEmbeddingTables does not support training when some tables are in eval mode." + ) scores = [] + # if self.training: for table_name in self._table_names: if table_name not in self._scores.keys(): raise RuntimeError( @@ -582,25 +995,44 @@ def forward( scores.append(self._scores[table_name]) if self.pooling_mode == DynamicEmbPoolingMode.NONE: - res = DynamicEmbeddingFunction.apply( + for i, cache in enumerate(self._caches): + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = True + table.set_score(self._scores[self.table_names[i]]) + for i, storage in enumerate(self._storages): + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + # if not training and not caching, we don't need to update score. + table.score_update = self.training or self._caching + table.set_score(self._scores[self.table_names[i]]) + res = DynamicEmbeddingFunctionV2.apply( indices, offsets, - self.use_index_dedup, - self.table_offsets_in_feature, - self._tables, - scores, - self.total_D, - self.dims[0], - self.feature_table_map, - self.embedding_dtype, + self._caches, + self._storages, + self.feature_offsets, self.output_dtype, - self.pooling_mode, - self._device_num_sms, - self._unique_op, - torch.device(self.device_id), + self._initializers if self.training else self._eval_initializers, self._optimizer, + self._unique_op, + self._enable_prefetch, + self.use_index_dedup, + self.training, + self._admit_strategy, + self._evict_strategy, + per_sample_weights, # Pass frequency counters as weights + self._admission_counter, self._empty_tensor, ) + for cache in self._caches: + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = False + for storage in self._storages: + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = False else: res = DynamicEmbeddingBagFunction.apply( indices, @@ -618,13 +1050,81 @@ def forward( self._device_num_sms, self._unique_op, torch.device(self.device_id), - self._optimizer, + self._bag_optimizer, + self.training, + [option.eval_initializer_args for option in self._dynamicemb_options], self._empty_tensor, ) - self._update_score() + # We have to update cache's core in eval mode. + if self.training or self._caching: + self._update_score() + return res + def prefetch( + self, + indices: Tensor, + offsets: Tensor, + forward_stream: Optional[torch.cuda.Stream] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, + ) -> None: + assert ( + self.pooling_mode == DynamicEmbPoolingMode.NONE + ), "only support prefetch for sequence embedding." + assert self._enable_prefetch, "Prefetch is not enabled." + if not self._caching: + logging.warning("Caching is not enabled, prefetch will do nothing.") + if self.prefetch_stream is None and forward_stream is not None: + # Set the prefetch stream to the current stream + self.prefetch_stream = torch.cuda.current_stream() + assert ( + self.prefetch_stream != forward_stream + ), "prefetch_stream and forward_stream should not be the same stream" + + current_stream = torch.cuda.current_stream() + # Record tensors on the current stream + indices.record_stream(current_stream) + offsets.record_stream(current_stream) + + if self._enable_prefetch: + self.num_prefetch_ahead += 1 + assert self.num_prefetch_ahead >= 1, "Prefetch context mismatches." + + prefetch_scores = self._get_prefetch_score() + + for i, cache in enumerate(self._caches): + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = True + table.set_score(prefetch_scores[i]) + for i, storage in enumerate(self._storages): + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = True + table.set_score(prefetch_scores[i]) + + dynamicemb_prefetch( + indices, + offsets, + self._caches, + self._storages, + self.feature_offsets, + self._initializers if self.training else self._eval_initializers, + self._unique_op, + self.training, + forward_stream, + ) + + for cache in self._caches: + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = False + for storage in self._storages: + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = False + def set_score( self, named_score: Dict[str, int], @@ -695,6 +1195,1021 @@ def _update_score(self): elif option.score_strategy == DynamicEmbScoreStrategy.LFU: self._scores[table_name] = 1 + def _get_prefetch_score( + self, + ): + ret_scores = [] + for table_name, option in zip(self._table_names, self._dynamicemb_options): + cur_score = self._scores[table_name] + if ( + self.enable_prefetch + and option.score_strategy == DynamicEmbScoreStrategy.STEP + ): + max_uint64 = (2**64) - 1 + new_score = cur_score + self.num_prefetch_ahead - 1 + if new_score > max_uint64: + warnings.warn( + f"Table '{table_name}' 's score({new_score}) is out of range, reset to 0.", + UserWarning, + ) + new_score = 0 + else: + new_score = cur_score + + ret_scores.append(new_score) + return ret_scores + + def dump( + self, + save_dir: str, + optim: bool = False, + counter: bool = False, + table_names: Optional[List[str]] = None, + pg: Optional[dist.ProcessGroup] = None, + ) -> None: + if table_names is None: + table_names = self._table_names + + if pg is None: + assert dist.is_initialized(), "Distributed is not initialized." + pg = dist.group.WORLD + rank = dist.get_rank(group=pg) + world_size = dist.get_world_size(group=pg) + + self.flush() + for table_name, storage, counter_table in zip( + self._table_names, self._storages, self._admission_counter + ): + if table_name not in set(table_names): + continue + + meta_file_path = encode_meta_json_file_path(save_dir, table_name) + emb_key_path = encode_checkpoint_file_path( + save_dir, table_name, rank, world_size, "keys" + ) + emb_value_path = encode_checkpoint_file_path( + save_dir, table_name, rank, world_size, "values" + ) + emb_score_path = encode_checkpoint_file_path( + save_dir, table_name, rank, world_size, "scores" + ) + opt_value_path = encode_checkpoint_file_path( + save_dir, table_name, rank, world_size, "opt_values" + ) + + if isinstance(storage, KeyValueTable) and not storage._use_score: + dist.barrier() # sync global timestamp + cast(KeyValueTable, storage).update_timestamp() + storage.dump( + meta_file_path, + emb_key_path, + emb_value_path, + emb_score_path, + opt_value_path, + include_optim=optim, + include_meta=(rank == 0), + ) + + if not counter: + continue + + counter_key_path = encode_counter_checkpoint_file_path( + save_dir, table_name, rank, world_size, "keys" + ) + counter_frequency_path = encode_counter_checkpoint_file_path( + save_dir, table_name, rank, world_size, "frequencies" + ) + + if counter_table is not None: + counter_table.dump(counter_key_path, counter_frequency_path) + else: + warnings.warn( + f"Counter table is none and will not dump it for table: {table_name}" + ) + + def load( + self, + save_dir: str, + optim: bool = False, + counter: bool = False, + table_names: Optional[List[str]] = None, + pg: Optional[dist.ProcessGroup] = None, + ): + if table_names is None: + table_names = self._table_names + + if pg is None and not dist.is_initialized(): # for inference load + rank = 0 + world_size = 1 + else: + rank = dist.get_rank(group=pg) + world_size = dist.get_world_size(group=pg) + + for table_name, storage, counter_table in zip( + self._table_names, self._storages, self._admission_counter + ): + if table_name not in set(table_names): + continue + ( + emb_key_files, + emb_value_files, + emb_score_files, + opt_value_files, + counter_key_files, + counter_frequency_files, + ) = get_loading_files( + save_dir, + table_name, + rank=rank, + world_size=world_size, + ) + meta_json_file = encode_meta_json_file_path(save_dir, table_name) + + if isinstance(storage, KeyValueTable) and not storage._use_score: + cast(KeyValueTable, storage).update_timestamp() + num_key_files = len(emb_key_files) + for i in range(num_key_files): + storage.load( + meta_json_file, + emb_key_files[i], + emb_value_files[i], + emb_score_files[i] if len(emb_score_files) > 0 else None, + opt_value_files[i] if len(opt_value_files) > 0 else None, + include_optim=optim, + ) + + if not counter: + continue + if counter_table is None: + warnings.warn( + f"Counter table is none and will not load for table: {table_name}" + ) + continue + num_counter_key_files = len(counter_key_files) + for i in range(num_counter_key_files): + counter_table.load(counter_key_files[i], counter_frequency_files[i]) + + def export_keys_values( + self, table_name: str, device: torch.device, batch_size: int = 65536 + ) -> Tuple[torch.Tensor, torch.Tensor]: + from dynamicemb.key_value_table import batched_export_keys_values + + keys_list = [] + values_list = [] + self.flush() + for dynamic_table_name, dynamic_table in zip(self.table_names, self.tables): + assert isinstance( + dynamic_table, KeyValueTable + ), "Only KeyValueTable is supported for batched export keys and values" + if table_name != dynamic_table_name: + continue + + local_max_rows = dynamic_table.size() + accumulated_counts = 0 + + for keys, embeddings, _, _ in batched_export_keys_values( + dynamic_table.table, device, batch_size + ): + keys_list.append(keys) + values_list.append(embeddings) + accumulated_counts += keys.numel() + + if local_max_rows != accumulated_counts: + raise ValueError( + f"Rank {dist.get_rank()} has accumulated count {accumulated_counts} which is different from expected {local_max_rows}, " + f"difference: {accumulated_counts - local_max_rows}" + ) + return torch.cat(keys_list), torch.cat(values_list, dim=0) + + def incremental_dump( + self, + named_thresholds: Dict[str, int] = None, + pg: Optional[dist.ProcessGroup] = None, + ) -> Tuple[Dict[str, Tuple[Tensor, Tensor]], Dict[str, int]]: + table_names: List[str] = named_thresholds.keys() + table_thresholds: List[int] = named_thresholds.values() + ret_tensors: Dict[str, Tuple[Tensor, Tensor]] = {} + ret_scores: Dict[str, int] = {} + + def _export_matched_per_table(pg, table, threshold): + if not dist.is_initialized() or dist.get_world_size(group=pg) == 1: + key, value = _export_matched(table, threshold) + else: + key, value = _export_matched_and_gather(table, threshold, pg) + return key, value + + for table_name, threshold in zip(table_names, table_thresholds): + index = self._table_names.index(table_name) + + storage = self._storages[index] + if not isinstance(storage, KeyValueTable): + raise RuntimeError( + "Only KeyValueTable is supported for incremental dump" + ) + key, value = _export_matched_per_table(pg, storage, threshold) + if self._caches[index] is not None: + # flush will change the score(timestamp) in storage + # self._caches[index].flush(self._storages[index]) + cache = self._caches[index] + key_c, value_c = _export_matched_per_table(pg, cache, threshold) + mask = ~torch.isin(key, key_c) + if key.numel() != 0: + if mask.sum() != 0: + value = ( + value.view(key.numel(), -1)[mask, :].contiguous().view(-1) + ) + key = key[mask].contiguous() + key = torch.cat((key_c, key), dim=0).contiguous() + value = torch.cat((value_c, value), dim=0).contiguous() + else: + key = key_c + value = value_c + else: + key = key_c + value = value_c + + ret_tensors[table_name] = (key, value) + ret_scores[table_name] = self._scores[table_name] + return ret_tensors, ret_scores + + +class BatchedDynamicEmbeddingTablesV2(nn.Module): + """ + Dynamic Embedding is based on [HKV](https://github.com/NVIDIA-Merlin/HierarchicalKV/tree/master). + Looks up one or more dynamic embedding tables. The module is application for training. + + Its optional to fuse the optimizer with the backward operator by parameter *update_grads_explicitly*. + """ + + optimizer_args: OptimizerArgs + + def __init__( + self, + table_options: List[DynamicEmbTableOptions], + table_names: Optional[List[str]] = None, + feature_table_map: Optional[List[int]] = None, # [T] + use_index_dedup: bool = False, + enable_prefetch: bool = False, + pooling_mode: DynamicEmbPoolingMode = DynamicEmbPoolingMode.SUM, + output_dtype: torch.dtype = torch.float32, + device: torch.device = None, + ext_ps: Optional[Storage] = None, + enforce_hbm: bool = False, # place all weights/momentums in HBM when using cache + bounds_check_mode: BoundsCheckMode = BoundsCheckMode.WARNING, + optimizer: EmbOptimType = EmbOptimType.SGD, + # General Optimizer args + stochastic_rounding: bool = True, + gradient_clipping: bool = False, + max_gradient: float = 1.0, + max_norm: float = 0.0, + learning_rate: float = 0.01, + # used by EXACT_ADAGRAD, EXACT_ROWWISE_ADAGRAD, EXACT_ROWWISE_WEIGHTED_ADAGRAD, LAMB, and ADAM only + # NOTE that default is different from nn.optim.Adagrad default of 1e-10 + eps: float = 1.0e-8, + # used by EXACT_ADAGRAD, EXACT_ROWWISE_ADAGRAD, and EXACT_ROWWISE_WEIGHTED_ADAGRAD only + initial_accumulator_value: float = 0.0, + momentum: float = 0.9, # used by LARS-SGD + # EXACT_ADAGRAD, SGD, EXACT_SGD do not support weight decay + # LAMB, ADAM, PARTIAL_ROWWISE_ADAM, PARTIAL_ROWWISE_LAMB, LARS_SGD support decoupled weight decay + # EXACT_ROWWISE_WEIGHTED_ADAGRAD supports L2 weight decay + # EXACT_ROWWISE_ADAGRAD support both L2 and decoupled weight decay (via weight_decay_mode) + weight_decay: float = 0.0, + weight_decay_mode: WeightDecayMode = WeightDecayMode.NONE, + eta: float = 0.001, # used by LARS-SGD, + beta1: float = 0.9, # used by LAMB and ADAM + beta2: float = 0.999, # used by LAMB and ADAM + counter_based_regularization: Optional[ + CounterBasedRegularizationDefinition + ] = None, # used by Rowwise Adagrad + cowclip_regularization: Optional[ + CowClipDefinition + ] = None, # used by Rowwise Adagrad + # TO align with FBGEMM TBE + *args, + **kwargs, + ) -> None: + super().__init__() + assert len(table_options) >= 1 + table_option = table_options[0] + for other_option in table_options: + assert ( + table_option == other_option + ), "All tables must match in grouped keys." + self._dynamicemb_options = table_options + self.initializer_args = table_option.initializer_args + self.index_type = table_option.index_type + self.embedding_dtype = table_option.embedding_dtype + self.output_dtype = output_dtype + self.pooling_mode = pooling_mode + self.use_index_dedup = use_index_dedup + self._enable_prefetch = enable_prefetch + self.prefetch_stream = None + self.num_prefetch_ahead = 0 + self._table_names = table_names + self.bounds_check_mode_int: int = bounds_check_mode.value + self._create_score() + + if device is not None: + self.device_id = int(str(device)[-1]) + else: + assert torch.cuda.is_available(), "No available CUDA device." + self.device_id = torch.cuda.current_device() + + if table_option.device_id is None: + for option in self._dynamicemb_options: + option.device_id = self.device_id + # get cuda device config + device_properties = torch.cuda.get_device_properties(self.device_id) + self._device_num_sms = device_properties.multi_processor_count + + self.dims: List[int] = [option.dim for option in self._dynamicemb_options] + # mixed D is not supported by sequence embedding. + mixed_D = False + D = self.dims[0] + for d in self.dims: + if d != D: + mixed_D = True + break + if mixed_D: + assert ( + self.pooling_mode != DynamicEmbPoolingMode.NONE + ), "Mixed dimension tables only supported for pooling tables." + + # physical table number. + T_ = len(self._dynamicemb_options) + assert T_ > 0 + self.feature_table_map: List[int] = ( + feature_table_map if feature_table_map is not None else list(range(T_)) + ) + # logical table number. + T = len(self.feature_table_map) + assert T_ <= T + table_has_feature = [False] * T_ + for t in self.feature_table_map: + table_has_feature[t] = True + assert all(table_has_feature), "Each table must have at least one feature!" + + feature_dims = [self.dims[t] for t in self.feature_table_map] + D_offsets = [0] + list(accumulate(feature_dims)) + self.total_D: int = D_offsets[-1] + self.max_D: int = max(self.dims) + + self.feature_num = len(self.feature_table_map) + # TODO:deal with shuffeld feature_table_map + self.table_offsets_in_feature: List[int] = [] + old_table_id = -1 + for idx, table_id in enumerate(self.feature_table_map): + if table_id != old_table_id: + self.table_offsets_in_feature.append(idx) + old_table_id = table_id + self.table_offsets_in_feature.append(self.feature_num) + self.feature_offsets = torch.tensor( + self.table_offsets_in_feature, + device=torch.device(self.device_id), + dtype=torch.int64, + ) + + for option in self._dynamicemb_options: + if option.init_capacity is None: + option.init_capacity = option.max_capacity + + if self.pooling_mode != DynamicEmbPoolingMode.NONE: + self._optimizer_type = optimizer + self._create_tables() + + self._optimizer: Union[ + BaseDynamicEmbeddingOptimizer, BaseDynamicEmbeddingOptimizerV2 + ] = None + self._create_optimizer( + optimizer, + stochastic_rounding, + gradient_clipping, + max_gradient, + max_norm, + learning_rate, + eps, + initial_accumulator_value, + beta1, + beta2, + weight_decay, + eta, + momentum, + weight_decay_mode, + counter_based_regularization, + cowclip_regularization, + ) + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + self._create_cache_storage(ext_ps) + self._initializers = [] + self._eval_initializers = [] + self._create_initializers() + + # TODO:1->10 + self._empty_tensor = nn.Parameter( + torch.empty( + 10, + requires_grad=True, + device=torch.device(self.device_id), + dtype=self.embedding_dtype, + ) + ) + + # new a unique op + # TODO: in our case maybe we can use torch.uint32 + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + count_dtype = torch.long + else: + count_dtype = torch.uint64 + reserve_keys = torch.tensor( + 2, dtype=self.index_type, device=torch.device(self.device_id) + ) + reserve_vals = torch.tensor( + 2, dtype=count_dtype, device=torch.device(self.device_id) + ) + counter = torch.tensor( + 1, dtype=count_dtype, device=torch.device(self.device_id) + ) + self._unique_op = UniqueOp(reserve_keys, reserve_vals, counter, 2) + + def _create_tables(self) -> None: + self._tables: List[DynamicEmbTable] = [] + for option in self._dynamicemb_options: + if option.training: + if self._optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: + option.optimizer_type = OptimizerType.RowWiseAdaGrad + elif ( + self._optimizer_type == EmbOptimType.SGD + or self._optimizer_type == EmbOptimType.EXACT_SGD + ): + option.optimizer_type = OptimizerType.SGD + elif self._optimizer_type == EmbOptimType.ADAM: + option.optimizer_type = OptimizerType.Adam + elif self._optimizer_type == EmbOptimType.EXACT_ADAGRAD: + option.optimizer_type = OptimizerType.AdaGrad + else: + raise ValueError( + f"Not supported optimizer type ,optimizer type = {self._optimizer_type} {type(self._optimizer_type)} {self._optimizer_type.value}." + ) + self._tables.append(create_dynamicemb_table(option)) + + def _create_cache_storage(self, PS: Storage = None) -> None: + self._storages: List[Storage] = [] + self._caches: List[Cache] = [] + self._caching = self._dynamicemb_options[0].caching + + for option in self._dynamicemb_options: + if option.training: + if self._optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: + option.optimizer_type = OptimizerType.RowWiseAdaGrad + elif ( + self._optimizer_type == EmbOptimType.SGD + or self._optimizer_type == EmbOptimType.EXACT_SGD + ): + option.optimizer_type = OptimizerType.SGD + elif self._optimizer_type == EmbOptimType.ADAM: + option.optimizer_type = OptimizerType.Adam + elif self._optimizer_type == EmbOptimType.EXACT_ADAGRAD: + option.optimizer_type = OptimizerType.AdaGrad + else: + raise ValueError( + f"Not supported optimizer type ,optimizer type = {self._optimizer_type} {type(self._optimizer_type)} {self._optimizer_type.value}." + ) + if option.caching and option.training: + cache_option = deepcopy(option) + cache_option.bucket_capacity = 1024 + capacity = get_constraint_capacity( + option.local_hbm_for_values, + option.embedding_dtype, + option.dim, + option.optimizer_type, + cache_option.bucket_capacity, + ) + if capacity == 0: + raise ValueError( + "Can't use caching mode as the reserved HBM size is too small." + ) + + cache_option.max_capacity = capacity + cache_option.init_capacity = capacity + self._caches.append(KeyValueTable(cache_option, self._optimizer)) + + storage_option = deepcopy(option) + storage_option.local_hbm_for_values = 0 + self._storages.append( + PS(storage_option, self._optimizer) + if PS + else KeyValueTable(storage_option, self._optimizer) + ) + else: + self._caches.append(None) + self._storages.append(KeyValueTable(option, self._optimizer)) + + _print_memory_consume( + self._table_names, self._dynamicemb_options, self._optimizer, self.device_id + ) + + def _create_initializers(self) -> None: + def _get_initializer(initializer_args): + mode = initializer_args.mode + if mode == DynamicEmbInitializerMode.NORMAL: + initializer = NormalInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.TRUNCATED_NORMAL: + initializer = TruncatedNormalInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.UNIFORM: + initializer = UniformInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.CONSTANT: + initializer = ConstantInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.DEBUG: + initializer = DebugInitializer(initializer_args) + else: + raise ValueError( + f"Not supported initializer type({mode}) {type(mode)} {mode.value}." + ) + return initializer + + for option in self._dynamicemb_options: + initializer = _get_initializer(option.initializer_args) + self._initializers.append(initializer) + eval_initializer = _get_initializer(option.eval_initializer_args) + self._eval_initializers.append(eval_initializer) + + def _create_table_optimizer( + self, + optimizer_type: EmbOptimType, + optimizer_args: OptimizerArgs, + ) -> None: + if optimizer_type == EmbOptimType.SGD: + self._optimizer = SGDDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + self._tables, + ) + elif optimizer_type == EmbOptimType.EXACT_SGD: + self._optimizer = SGDDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + self._tables, + ) + elif optimizer_type == EmbOptimType.ADAM: + self._optimizer = AdamDynamicEmbeddingOptimizer( + optimizer_args, + self._dynamicemb_options, + self._tables, + ) + elif optimizer_type == EmbOptimType.EXACT_ADAGRAD: + self._optimizer = AdaGradDynamicEmbeddingOptimizer( + optimizer_args, self._dynamicemb_options, self._tables + ) + elif optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: + self._optimizer = RowWiseAdaGradDynamicEmbeddingOptimizer( + optimizer_args, self._dynamicemb_options, self._tables + ) + else: + raise ValueError( + f"Not supported optimizer type ,optimizer type = {optimizer_type} {type(optimizer_type)} {optimizer_type.value}." + ) + + def _create_optimizer( + self, + optimizer_type: EmbOptimType, + stochastic_rounding: bool, + gradient_clipping: bool, + max_gradient: float, + max_norm: float, + learning_rate: float, + eps: float, + initial_accumulator_value: float, + beta1: float, + beta2: float, + weight_decay: float, + eta: float, + momentum: float, + weight_decay_mode: WeightDecayMode, + counter_based_regularization: Optional[CounterBasedRegularizationDefinition], + cowclip_regularization: Optional[CowClipDefinition], + ) -> None: + self._optimizer_type = optimizer_type + self.stochastic_rounding = stochastic_rounding + + self.weight_decay_mode = weight_decay_mode + if (weight_decay_mode == WeightDecayMode.COUNTER) != ( + counter_based_regularization is not None + ): + raise AssertionError( + "Need to set weight_decay_mode=WeightDecayMode.COUNTER together with valid counter_based_regularization" + ) + if (weight_decay_mode == WeightDecayMode.COWCLIP) != ( + cowclip_regularization is not None + ): + raise AssertionError( + "Need to set weight_decay_mode=WeightDecayMode.COWCLIP together with valid cowclip_regularization" + ) + + self._used_rowwise_adagrad_with_counter: bool = ( + optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD + and ( + weight_decay_mode in (WeightDecayMode.COUNTER, WeightDecayMode.COWCLIP) + ) + ) + + if counter_based_regularization is None: + counter_based_regularization = CounterBasedRegularizationDefinition() + if cowclip_regularization is None: + cowclip_regularization = CowClipDefinition() + self._max_counter_update_freq: int = -1 + # Extract parameters from CounterBasedRegularizationDefinition or CowClipDefinition + # which are passed as entries for OptimizerArgs + if self._used_rowwise_adagrad_with_counter: + if self.weight_decay_mode == WeightDecayMode.COUNTER: + self._max_counter_update_freq = ( + counter_based_regularization.max_counter_update_freq + ) + opt_arg_weight_decay_mode = ( + counter_based_regularization.counter_weight_decay_mode + ) + counter_halflife = counter_based_regularization.counter_halflife + else: + opt_arg_weight_decay_mode = ( + cowclip_regularization.counter_weight_decay_mode + ) + counter_halflife = cowclip_regularization.counter_halflife + else: + opt_arg_weight_decay_mode = weight_decay_mode + # Default: -1, no decay applied, as a placeholder for OptimizerArgs + # which should not be effective when CounterBasedRegularizationDefinition + # and CowClipDefinition are not used + counter_halflife = -1 + + optimizer_args = OptimizerArgs( + stochastic_rounding=stochastic_rounding, + gradient_clipping=gradient_clipping, + max_gradient=max_gradient, + max_norm=max_norm, + learning_rate=learning_rate, + eps=eps, + initial_accumulator_value=initial_accumulator_value, + beta1=beta1, + beta2=beta2, + weight_decay=weight_decay, + weight_decay_mode=opt_arg_weight_decay_mode.value, + eta=eta, + momentum=momentum, + counter_halflife=counter_halflife, + adjustment_iter=counter_based_regularization.adjustment_iter, + adjustment_ub=counter_based_regularization.adjustment_ub, + learning_rate_mode=counter_based_regularization.learning_rate_mode.value, + grad_sum_decay=counter_based_regularization.grad_sum_decay.value, + tail_id_threshold=counter_based_regularization.tail_id_threshold.val, + is_tail_id_thresh_ratio=int( + counter_based_regularization.tail_id_threshold.is_ratio + ), + total_hash_size=0, + weight_norm_coefficient=cowclip_regularization.weight_norm_coefficient, + lower_bound=cowclip_regularization.lower_bound, + regularization_mode=weight_decay_mode.value, + ) + if self.pooling_mode != DynamicEmbPoolingMode.NONE: + self._create_table_optimizer(optimizer_type, optimizer_args) + return + + if optimizer_type == EmbOptimType.SGD: + self._optimizer = SGDDynamicEmbeddingOptimizerV2( + optimizer_args, + ) + elif optimizer_type == EmbOptimType.EXACT_SGD: + self._optimizer = SGDDynamicEmbeddingOptimizerV2( + optimizer_args, + ) + elif optimizer_type == EmbOptimType.ADAM: + self._optimizer = AdamDynamicEmbeddingOptimizerV2( + optimizer_args, + ) + elif optimizer_type == EmbOptimType.EXACT_ADAGRAD: + self._optimizer = AdaGradDynamicEmbeddingOptimizerV2( + optimizer_args, + ) + elif optimizer_type == EmbOptimType.EXACT_ROWWISE_ADAGRAD: + self._optimizer = RowWiseAdaGradDynamicEmbeddingOptimizerV2( + optimizer_args, + self.embedding_dtype, + ) + else: + raise ValueError( + f"Not supported optimizer type ,optimizer type = {optimizer_type} {type(optimizer_type)} {optimizer_type.value}." + ) + + def split_embedding_weights(self) -> List[Tensor]: + """ + Returns a list of weights, split by table + """ + splits = [] + for t, _ in enumerate(self._dynamicemb_options): + splits.append( + torch.empty( + (1, 1), device=torch.device("cuda"), dtype=self.embedding_dtype + ) + ) + return splits + + def flush(self) -> None: + self.num_prefetch_ahead = 0 + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + for cache, storage in zip(self._caches, self._storages): + cache.flush(storage) + + def reset_cache_states(self) -> None: + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + for cache in self._caches: + cache.reset() + + @property + def table_names(self) -> List[str]: + return self._table_names + + @property + def optimizer(self) -> BaseDynamicEmbeddingOptimizer: + return self._optimizer + + @property + def tables(self) -> List[DynamicEmbTable]: + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + return self._storages + else: + return self._tables + + @property + def caches(self) -> List[Cache]: + return self._caches + + def set_record_cache_metrics(self, record: bool) -> None: + for cache in self._caches: + cache.set_record_cache_metrics(record) + + def set_learning_rate(self, lr: float) -> None: + self._optimizer.set_learning_rate(lr) + return + + @property + def enable_prefetch( + self, + ) -> None: + return self._enable_prefetch + + @enable_prefetch.setter + def enable_prefetch(self, value: bool): + self._enable_prefetch = value + self.num_prefetch_ahead = 0 + + def forward( + self, + indices: Tensor, + offsets: Tensor, + per_sample_weights: Optional[Tensor] = None, + feature_requires_grad: Optional[Tensor] = None, + # 2D tensor of batch size for each rank and feature. + # Shape (number of features, number of ranks) + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, + total_unique_indices: Optional[int] = None, + ) -> List[Tensor]: + if self._enable_prefetch: + self.num_prefetch_ahead -= 1 + + if indices.dtype != self.index_type: + indices = indices.to(self.index_type) + + if any([not o.training for o in self._dynamicemb_options]) and self.training: + raise RuntimeError( + "BatchedDynamicEmbeddingTables does not support training when some tables are in eval mode." + ) + + scores = [] + # if self.training: + for table_name in self._table_names: + if table_name not in self._scores.keys(): + raise RuntimeError( + f"Must set score for table '{table_name}' whose score_strategy is customized." + ) + scores.append(self._scores[table_name]) + + if self.pooling_mode == DynamicEmbPoolingMode.NONE: + for i, cache in enumerate(self._caches): + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = True + table.set_score(self._scores[self.table_names[i]]) + for i, storage in enumerate(self._storages): + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = True + table.set_score(self._scores[self.table_names[i]]) + res = DynamicEmbeddingFunctionV2.apply( + indices, + offsets, + self._caches, + self._storages, + self.feature_offsets, + self.output_dtype, + self._initializers if self.training else self._eval_initializers, + self._optimizer, + self._unique_op, + self._enable_prefetch, + self.use_index_dedup, + self.training, + per_sample_weights, # Pass frequency counters as weights + self._empty_tensor, + ) + for cache in self._caches: + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = False + for storage in self._storages: + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = False + else: + res = DynamicEmbeddingBagFunction.apply( + indices, + offsets, + self.use_index_dedup, + self.table_offsets_in_feature, + self._tables, + scores, + self.total_D, + self.dims, + self.feature_table_map, + self.embedding_dtype, + self.output_dtype, + self.pooling_mode, + self._device_num_sms, + self._unique_op, + torch.device(self.device_id), + self._optimizer, + self.training, + [option.eval_initializer_args for option in self._dynamicemb_options], + self._empty_tensor, + ) + + # We have to update cache's core in eval mode. + if self.training or self._caching: + self._update_score() + + return res + + def prefetch( + self, + indices: Tensor, + offsets: Tensor, + forward_stream: Optional[torch.cuda.Stream] = None, + batch_size_per_feature_per_rank: Optional[List[List[int]]] = None, + ) -> None: + assert ( + self.pooling_mode == DynamicEmbPoolingMode.NONE + ), "only support prefetch for sequence embedding." + + if self.prefetch_stream is None and forward_stream is not None: + # Set the prefetch stream to the current stream + self.prefetch_stream = torch.cuda.current_stream() + assert ( + self.prefetch_stream != forward_stream + ), "prefetch_stream and forward_stream should not be the same stream" + + current_stream = torch.cuda.current_stream() + # Record tensors on the current stream + indices.record_stream(current_stream) + offsets.record_stream(current_stream) + + if self._enable_prefetch: + self.num_prefetch_ahead += 1 + assert self.num_prefetch_ahead >= 1, "Prefetch context mismatches." + + prefetch_scores = self._get_prefetch_score() + + for i, cache in enumerate(self._caches): + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = True + table.set_score(prefetch_scores[i]) + for i, storage in enumerate(self._storages): + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = True + table.set_score(prefetch_scores[i]) + + dynamicemb_prefetch( + indices, + offsets, + self._caches, + self._storages, + self.feature_offsets, + self._initializers if self.training else self._eval_initializers, + self._unique_op, + self.training, + forward_stream, + ) + + for cache in self._caches: + if isinstance(cache, KeyValueTable): + table = cast(KeyValueTable, cache) + table.score_update = False + for storage in self._storages: + if isinstance(storage, KeyValueTable): + table = cast(KeyValueTable, storage) + table.score_update = False + + def set_score( + self, + named_score: Dict[str, int], + ) -> None: + table_names: List[str] = named_score.keys() + table_scores: List[int] = named_score.values() + for table_name, table_score in zip(table_names, table_scores): + if not isinstance(table_score, int): + raise ValueError( + f"Table's score is expect to int but got {type(table_score)}" + ) + if table_score == 0: + raise ValueError(f"Can't set table's score to 0.") + index = self._table_names.index(table_name) + assert ( + self._dynamicemb_options[index].score_strategy + == DynamicEmbScoreStrategy.CUSTOMIZED + ), "Can only set score for table whose score_strategy is DynamicEmbScoreStrategy.CUSTOMIZED." + + if table_name in self._scores and self._scores[table_name] > table_score: + if warning_for_cstm_score(): + warnings.warn( + f"New set score is less than the old one for table '{table_name}': {table_score} < {self._scores[table_name]}", + UserWarning, + ) + self._scores[table_name] = table_score + + def get_score(self) -> Dict[str, int]: + return self._scores.copy() + + def _create_score(self): + self._scores: Dict[str, int] = {} + for table_name, option in zip(self._table_names, self._dynamicemb_options): + if option.score_strategy == DynamicEmbScoreStrategy.TIMESTAMP: + option.evict_strategy = DynamicEmbEvictStrategy.LRU + self._scores[table_name] = device_timestamp() + elif option.score_strategy == DynamicEmbScoreStrategy.STEP: + option.evict_strategy = DynamicEmbEvictStrategy.CUSTOMIZED + self._scores[table_name] = 1 + elif option.score_strategy == DynamicEmbScoreStrategy.CUSTOMIZED: + option.evict_strategy = DynamicEmbEvictStrategy.CUSTOMIZED + elif option.score_strategy == DynamicEmbScoreStrategy.LFU: + option.evict_strategy = DynamicEmbEvictStrategy.LFU + self._scores[table_name] = 1 + + def _update_score(self): + for table_name, option in zip(self._table_names, self._dynamicemb_options): + old_score = self._scores[table_name] + if option.score_strategy == DynamicEmbScoreStrategy.TIMESTAMP: + new_score = device_timestamp() + if new_score < old_score: + warnings.warn( + f"Table '{table_name}' 's score({new_score}) is less than old one({old_score}).", + UserWarning, + ) + self._scores[table_name] = new_score + elif option.score_strategy == DynamicEmbScoreStrategy.STEP: + max_uint64 = (2**64) - 1 + new_score = old_score + 1 + if new_score > max_uint64: + warnings.warn( + f"Table '{table_name}' 's score({new_score}) is out of range, reset to 0.", + UserWarning, + ) + self._scores[table_name] = 0 + else: + self._scores[table_name] = new_score + elif option.score_strategy == DynamicEmbScoreStrategy.LFU: + self._scores[table_name] = 1 + + def _get_prefetch_score( + self, + ): + ret_scores = [] + for table_name, option in zip(self._table_names, self._dynamicemb_options): + cur_score = self._scores[table_name] + if ( + self.enable_prefetch + and option.score_strategy == DynamicEmbScoreStrategy.STEP + ): + max_uint64 = (2**64) - 1 + new_score = cur_score + self.num_prefetch_ahead - 1 + if new_score > max_uint64: + warnings.warn( + f"Table '{table_name}' 's score({new_score}) is out of range, reset to 0.", + UserWarning, + ) + new_score = 0 + else: + new_score = cur_score + + ret_scores.append(new_score) + return ret_scores + def incremental_dump( self, named_thresholds: Dict[str, int] = None, @@ -715,3 +2230,4 @@ def incremental_dump( ret_tensors[table_name] = (key, value) ret_scores[table_name] = self._scores[table_name] return ret_tensors, ret_scores +BatchedDynamicEmbeddingTables = BatchedDynamicEmbeddingTablesV2 diff --git a/corelib/dynamicemb/dynamicemb/dump_load.py b/corelib/dynamicemb/dynamicemb/dump_load.py index 39792aadc..51941f705 100644 --- a/corelib/dynamicemb/dynamicemb/dump_load.py +++ b/corelib/dynamicemb/dynamicemb/dump_load.py @@ -927,7 +927,9 @@ def DynamicEmbDump( model: nn.Module, table_names: Optional[Dict[str, List[str]]] = None, optim: Optional[bool] = False, - pg: Optional[dist.ProcessGroup] = None, + counter: Optional[bool] = False, + pg: dist.ProcessGroup = dist.group.WORLD, + allow_overwrite: bool = False, ) -> None: """ Dump the distributed weights and corresponding optimizer states of dynamic embedding tables from the model to the filesystem. @@ -948,6 +950,8 @@ def DynamicEmbDump( and the value is a list of dynamic embedding table names within that collection. Defaults to None. optim : Optional[bool], optional Whether to dump the optimizer states. Defaults to False. + counter : Optional[bool], optional + Whether to dump the embedding admission counter table. Defaults to False. pg : Optional[dist.ProcessGroup], optional The process group used to control the communication scope in the dump. Defaults to None. @@ -1071,62 +1075,20 @@ def DynamicEmbDump( for i, tmp_collection in enumerate(collections_list): collection_path, tmp_collection_name, tmp_collection_module = tmp_collection full_collection_path = os.path.join(path, collection_path) - tmp_dynamic_emb_module_list = get_dynamic_emb_module(tmp_collection_module) - - for j, dynamic_emb_module in enumerate(tmp_dynamic_emb_module_list): - tmp_table_names = dynamic_emb_module.table_names - tmp_tables = dynamic_emb_module.tables - - filtered_table_names: List[str] = [] - filtered_dynamic_tables: List[DynamicEmbTable] = [] - # TODO:need a warning - if table_names is not None: - tmp_input_names = table_names[tmp_collection_name] - for name in tmp_input_names: - if name in tmp_table_names: - index = tmp_table_names.index(name) - filtered_table_names.append(tmp_table_names[index]) - filtered_dynamic_tables.append(tmp_tables[index]) - else: - filtered_table_names = tmp_table_names - filtered_dynamic_tables = tmp_tables - if len(filtered_table_names) == 0: - continue - - if optim: - optimizer = dynamic_emb_module.optimizer - opt_args = optimizer.get_opt_args() - - tmp_tables_dict: Dict[str, DynamicEmbTable] = { - name: table - for name, table in zip(filtered_table_names, filtered_dynamic_tables) - } - - if rank == 0: - # Rank 0 determines the order of keys - ordered_keys = tmp_tables_dict.keys() - ordered_keys_str = ",".join(ordered_keys) - else: - ordered_keys_str = "" - - ordered_keys_str = broadcast_string(ordered_keys_str, rank=rank, pg=pg) - ordered_keys = ordered_keys_str.split(",") - - for k, dump_name in enumerate(ordered_keys): - dynamic_table = tmp_tables_dict[dump_name] - gather_and_export( - dynamic_table, full_collection_path, dump_name, pg=pg, optim=optim - ) - - if optim: - args_filename = dump_name + "_opt_args.json" - args_path = os.path.join(full_collection_path, args_filename) - save_to_json(opt_args, args_path) - if rank == 0: - print( - f"DynamicEmb dump table {dump_name} from module {tmp_collection_name} success!" - ) - + current_dynamic_emb_module_list = get_dynamic_emb_module( + tmp_collection_module + ) + table_names_to_dump = ( + table_names.get(collection_path, None) if table_names else None + ) + for dynamic_emb_module in current_dynamic_emb_module_list: + dynamic_emb_module.dump( + full_collection_path, + optim=optim, + counter=counter, + table_names=table_names_to_dump, + pg=pg, + ) if torch.cuda.is_available(): torch.cuda.synchronize() @@ -1150,7 +1112,8 @@ def DynamicEmbLoad( model: nn.Module, table_names: Optional[List[str]] = None, optim: bool = False, - pg: Optional[dist.ProcessGroup] = None, + counter: bool = False, + pg: dist.ProcessGroup = dist.group.WORLD, ): """ Load the distributed weights and corresponding optimizer states of dynamic embedding tables from the filesystem into the model. @@ -1169,6 +1132,8 @@ def DynamicEmbLoad( and the value is a list of dynamic embedding table names within that collection. Defaults to None. optim : bool, optional Whether to load the optimizer states. Defaults to False. + counter : bool, optional + Whether to load the embedding admission counter table. Defaults to False. pg : Optional[dist.ProcessGroup], optional The process group used to control the communication scope in the load. Defaults to None. diff --git a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py index 59101c08b..ad0dcbef0 100644 --- a/corelib/dynamicemb/dynamicemb/dynamicemb_config.py +++ b/corelib/dynamicemb/dynamicemb/dynamicemb_config.py @@ -17,14 +17,20 @@ import os from dataclasses import dataclass, field, fields from math import sqrt -from typing import Optional +from typing import Dict, Optional import torch +from dynamicemb.types import ( + AdmissionStrategy, + Counter, + DynamicEmbInitializerArgs, + DynamicEmbInitializerMode, + Storage, +) from dynamicemb_extensions import ( DynamicEmbDataType, DynamicEmbTable, EvictStrategy, - InitializerArgs, OptimizerType, ) from torchrec.modules.embedding_configs import BaseEmbeddingConfig @@ -45,91 +51,6 @@ def warning_for_cstm_score() -> None: DynamicEmbKernel = "DynamicEmb" -class DynamicEmbInitializerMode(enum.Enum): - """ - Enumeration for different modes of initializing dynamic embedding vector values. - - Attributes - ---------- - NORMAL : str - Normal Distribution. - UNIFORM : str - Uniform distribution of random values. - CONSTANT : str - All dynamic embedding vector values are a given constant. - DEBUG : str - Debug value generation mode for testing. - """ - - NORMAL = "normal" - TRUNCATED_NORMAL = "truncated_normal" - UNIFORM = "uniform" - CONSTANT = "constant" - DEBUG = "debug" - - -@dataclass -class DynamicEmbInitializerArgs: - """ - Arguments for initializing dynamic embedding vector values. - - Attributes - ---------- - mode : DynamicEmbInitializerMode - The mode of initialization, one of the DynamicEmbInitializerMode values. - mean : float, optional - The mean value for (truncated) normal distributions. Defaults to 0.0. - std_dev : float, optional - The standard deviation for (truncated) normal distributions. Defaults to 1.0. - lower : float, optional - The lower bound for uniform/truncated_normal distribution. Defaults to 0.0. - upper : float, optional - The upper bound for uniform/truncated_normal distribution. Defaults to 1.0. - value : float, optional - The constant value for constant initialization. Defaults to 0.0. - """ - - mode: DynamicEmbInitializerMode = DynamicEmbInitializerMode.UNIFORM - mean: float = 0.0 - std_dev: float = 1.0 - lower: float = None - upper: float = None - value: float = 0.0 - - def __eq__(self, other): - if not isinstance(other, DynamicEmbInitializerArgs): - return NotImplementedError - if self.mode == DynamicEmbInitializerMode.NORMAL: - return self.mean == other.mean and self.std_dev == other.std_dev - elif self.mode == DynamicEmbInitializerMode.TRUNCATED_NORMAL: - return ( - self.mean == other.mean - and self.std_dev == other.std_dev - and self.lower == other.lower - and self.upper == other.upper - ) - elif self.mode == DynamicEmbInitializerMode.UNIFORM: - return self.lower == other.lower and self.upper == other.upper - elif self.mode == DynamicEmbInitializerMode.CONSTANT: - return self.value == other.value - return True - - def __ne__(self, other): - if not isinstance(other, DynamicEmbInitializerArgs): - return NotImplementedError - return not (self == other) - - def as_ctype(self) -> InitializerArgs: - return InitializerArgs( - self.mode.value, - self.mean, - self.std_dev, - self.lower if self.lower else 0.0, - self.upper if self.upper else 1.0, - self.value, - ) - - @enum.unique class DynamicEmbCheckMode(enum.IntEnum): """ @@ -310,6 +231,9 @@ class DynamicEmbTableOptions(HKVConfig): initializer_args : DynamicEmbInitializerArgs Arguments for initializing dynamic embedding vector values. Default is uniform distribution, and absolute values of upper and lower bound are sqrt(1 / eb_config.num_embeddings). + eval_initializer_args: DynamicEmbInitializerArgs + The initializer args for evaluation mode. + Default is constant initialization with value 0.0. score_strategy(DynamicEmbScoreStrategy): The strategy to set the score for each indices in forward and backward per table. Default to DynamicEmbScoreStrategy.TIMESTAMP. @@ -318,9 +242,25 @@ class DynamicEmbTableOptions(HKVConfig): safe_check_mode : DynamicEmbCheckMode Should dynamic embedding table insert safe check be enabled? By default, it is disabled. Please refer to the API documentation for DynamicEmbCheckMode for more information. - training: bool - Flag to indicate dynamic embedding tables is working on training mode or evaluation mode, default to `True`. - + global_hbm_for_values : int + Total GPU memory allocated to store embedding + optimizer states, in bytes. Default is 0. + It has different meanings under `caching=True` and `caching=False`. + When `caching=False`, it decides how much GPU memory is in the total memory to store value in a single hybrid table. + When `caching=True`, it decides the table capacity of the GPU table. + external_storage: Storage + The external storage/ParamterServer which inherits the interface of Storage, and can be configured per table. + If not provided, will using KeyValueTable as the Storage. + index_type : Optional[torch.dtype], optional + Index type of sparse features, will be set to DEFAULT_INDEX_TYPE(torch.int64) by default. + admit_strategy : Optional[AdmissionStrategy], optional + Admission strategy for controlling which keys are allowed to enter the embedding table. + If provided, only keys that meet the strategy's criteria will be inserted into the table. + Keys that don't meet the criteria will still be initialized and used in the forward pass, + but won't be stored in the table. Default is None (all keys are admitted). + admission_counter : Optional[Counter], optional + Counter for tracking the number of keys that have been admitted to the embedding table. + If provided, the counter will be used to track the number of keys that have been admitted to the embedding table. + Default is None (no counter is used). Notes ----- For detailed descriptions and additional context on each parameter, please refer to the documentation at @@ -330,8 +270,38 @@ class DynamicEmbTableOptions(HKVConfig): initializer_args: DynamicEmbInitializerArgs = field( default_factory=DynamicEmbInitializerArgs ) + eval_initializer_args: DynamicEmbInitializerArgs = field( + default_factory=lambda: DynamicEmbInitializerArgs( + mode=DynamicEmbInitializerMode.CONSTANT, + value=0.0, + ) + ) score_strategy: DynamicEmbScoreStrategy = DynamicEmbScoreStrategy.TIMESTAMP - training: bool = True + bucket_capacity: int = 128 + safe_check_mode: DynamicEmbCheckMode = DynamicEmbCheckMode.IGNORE + global_hbm_for_values: int = 0 # in bytes + external_storage: Storage = None + index_type: Optional[torch.dtype] = None + admit_strategy: Optional[AdmissionStrategy] = None + admission_counter: Optional[Counter] = None + + def __post_init__(self): + assert ( + self.eval_initializer_args.mode == DynamicEmbInitializerMode.CONSTANT + ), "eval_initializer_args must be constant initialization" + + if self.init_capacity is not None: + target_init_capacity = _next_power_of_2(self.init_capacity) + if self.init_capacity != target_init_capacity: + warnings.warn( + f"init_capacity is changed to {target_init_capacity} from {self.init_capacity}" + ) + self.init_capacity = target_init_capacity + + def __post_init__(self): + assert ( + self.eval_initializer_args.mode == DynamicEmbInitializerMode.CONSTANT + ), "eval_initializer_args must be constant initialization" def __eq__(self, other): if not isinstance(other, DynamicEmbTableOptions): @@ -348,6 +318,11 @@ def __ne__(self, other): def get_grouped_key(self): grouped_key = {f.name: getattr(self, f.name) for f in fields(GroupedHKVConfig)} grouped_key["training"] = self.training + grouped_key["caching"] = self.caching + grouped_key["external_storage"] = self.external_storage + grouped_key["index_type"] = self.index_type + grouped_key["score_strategy"] = self.score_strategy + grouped_key["admit_strategy"] = self.admit_strategy return grouped_key def __hash__(self): @@ -495,6 +470,7 @@ def create_dynamicemb_table(table_options: DynamicEmbTableOptions) -> DynamicEmb ) +# TODO: sync with table def validate_initializer_args( initializer_args: DynamicEmbInitializerArgs, eb_config: BaseEmbeddingConfig = None ) -> None: @@ -505,3 +481,33 @@ def validate_initializer_args( initializer_args.lower = default_lower if initializer_args.upper is None: initializer_args.upper = default_upper + + +def get_optimizer_state_dim(optimizer_type, dim, dtype): + DTYPE_NUM_BYTES: Dict[torch.dtype, int] = { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + } + if optimizer_type == OptimizerType.RowWiseAdaGrad: + return 16 // DTYPE_NUM_BYTES[dtype] + elif optimizer_type == OptimizerType.Adam: + return dim * 2 + elif optimizer_type == OptimizerType.AdaGrad: + return dim + else: + return 0 + + +def get_constraint_capacity( + memory_bytes, + dtype, + dim, + optimizer_type, + bucket_capacity, +) -> int: + byte_consume = ( + dim + get_optimizer_state_dim(optimizer_type, dim, dtype) + ) * dtype_to_bytes(dtype) + capacity = memory_bytes // byte_consume + return (capacity // bucket_capacity) * bucket_capacity diff --git a/corelib/dynamicemb/dynamicemb/embedding_admission.py b/corelib/dynamicemb/dynamicemb/embedding_admission.py new file mode 100644 index 000000000..86714c150 --- /dev/null +++ b/corelib/dynamicemb/dynamicemb/embedding_admission.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import torch +from dynamicemb.initializer import create_initializer_from_args +from dynamicemb.scored_hashtable import ( + ScoreArg, + ScorePolicy, + ScoreSpec, + get_scored_table, +) +from dynamicemb.types import ( + AdmissionStrategy, + Counter, + DynamicEmbInitializerArgs, + MemoryType, +) + + +class KVCounter(Counter): + """ + Interface of a counter table which maps a key to a counter. + """ + + def __init__( + self, + capacity: int, + bucket_capacity: Optional[int] = 128, + key_type: Optional[torch.dtype] = torch.int64, + device: torch.device = None, + ): + self.score_name_ = "counter" + self.score_specs_ = [ + ScoreSpec(name=self.score_name_, policy=ScorePolicy.ACCUMULATE) + ] + self.score_args_ = [ScoreArg(name=self.score_name_, is_return=True)] + + self.table_ = get_scored_table( + capacity, bucket_capacity, key_type, self.score_specs_, device + ) + + def add( + self, keys: torch.Tensor, frequencies: torch.Tensor, inplace: bool + ) -> torch.Tensor: + """ + Add keys with frequencies to the `Counter` and get accumulated counter of each key. + For not existed keys, the frequencies will be assigned directly. + For existing keys, the frequencies will be accumulated. + + Args: + keys (torch.Tensor): The input keys, should be unique keys. + frequencies (torch.Tensor): The input frequencies, serve as initial or incremental values of frequencies' states. + inplace: If true then store the accumulated_frequencies to counter. + + Returns: + accumulated_frequencies (torch.Tensor): the frequencies' state in the `Counter` for the input keys. + """ + assert inplace == True, "Only support inplace=True" + self.score_args_[0].value = frequencies + + self.table_.insert(keys, self.score_args_) + return frequencies + + def erase(self, keys) -> None: + """ + Erase keys form the `Counter`. + + Args: + keys (torch.Tensor): The input keys to be erased. + """ + self.table_.erase(keys) + + def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: + """ + Get the consumption of a specific memory type. + + Args: + mem_type (MemoryType): the specific memory type, default to MemoryType.DEVICE. + """ + return self.table_.memory_usage(mem_type) + + def load(self, key_file, counter_file) -> None: + """ + Load keys and frequencies from input file path. + + Args: + key_file (str): the file path of keys. + counter_file (str): the file path of frequencies. + """ + self.table_.load(key_file, {self.score_name_: counter_file}) + + def dump(self, key_file, counter_file) -> None: + """ + Dump keys and frequencies to output file path. + + Args: + key_file (str): the file path of keys. + counter_file (str): the file path of frequencies. + """ + self.table_.dump(key_file, {self.score_name_: counter_file}) + + +class FrequencyAdmissionStrategy(AdmissionStrategy): + """ + Frequency-based admission strategy. + Only admits keys whose frequency (score) meets or exceeds a threshold. + + Parameters + ---------- + threshold : int + Minimum frequency threshold for admission. Keys with frequency >= threshold + will be admitted into the embedding table. + initializer_args: Optional[DynamicEmbInitializerArgs] + Initializer arguments which determine how to initialize the embedding if the key is not admitted. + """ + + def __init__( + self, + threshold: int, + initializer_args: Optional[DynamicEmbInitializerArgs] = None, + ): + if threshold < 0: + raise ValueError(f"Threshold must be non-negative, got {threshold}") + + self.threshold = threshold + self.initializer_args = initializer_args + + def admit( + self, + keys: torch.Tensor, + frequencies: torch.Tensor, + ) -> torch.Tensor: + """ + Admit keys with frequencies >= threshold. + + Parameters + ---------- + keys : torch.Tensor + Keys to evaluate (shape: [N]) + frequencies : torch.Tensor + Frequency counts for each key (shape: [N]) + + Returns + ------- + torch.Tensor + Boolean mask (shape: [N]) where True indicates admission + """ + if keys.shape[0] != frequencies.shape[0]: + raise ValueError( + f"Keys and frequencies must have same length, got {keys.shape[0]} and {frequencies.shape[0]}" + ) + + # Admit keys whose frequency meets or exceeds threshold + admit_mask = frequencies >= self.threshold + return admit_mask + + def initialize_non_admitted_embeddings( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + ) -> bool: + """ + Initialize the embeddings for the keys that are not admitted. + + Returns: + bool: True if the embeddings are initialized, False otherwise. + """ + if self.initializer_args is None: + return False + non_admit_initializer = create_initializer_from_args(self.initializer_args) + non_admit_initializer( + buffer, + indices, + None, + ) + return True diff --git a/corelib/dynamicemb/dynamicemb/initializer.py b/corelib/dynamicemb/dynamicemb/initializer.py new file mode 100644 index 000000000..46a33e0d0 --- /dev/null +++ b/corelib/dynamicemb/dynamicemb/initializer.py @@ -0,0 +1,130 @@ +import abc + +from dynamicemb.dynamicemb_config import * +from dynamicemb_extensions import ( + CurandStateContext, + const_init, + debug_init, + normal_init, + truncated_normal_init, + uniform_init, +) + + +class BaseDynamicEmbInitializer(abc.ABC): + def __init__(self, args: DynamicEmbInitializerArgs): + self._args = args + if self._args.lower is None: + self._args.lower = 0.0 + if self._args.upper is None: + self._args.upper = 1.0 + + @abc.abstractmethod + def __call__( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + keys: Optional[torch.Tensor], # remove it when debug mode is removed + ) -> None: + ... + + +class NormalInitializer(BaseDynamicEmbInitializer): + def __init__(self, args: DynamicEmbInitializerArgs): + super().__init__(args) + self._curand_state = CurandStateContext() + + def __call__( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + keys: Optional[torch.Tensor], # remove it when debug mode is removed + ) -> None: + normal_init( + buffer, indices, self._curand_state, self._args.mean, self._args.std_dev + ) + + +class TruncatedNormalInitializer(BaseDynamicEmbInitializer): + def __init__(self, args: DynamicEmbInitializerArgs): + super().__init__(args) + self._curand_state = CurandStateContext() + + def __call__( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + keys: Optional[torch.Tensor], # remove it when debug mode is removed + ) -> None: + truncated_normal_init( + buffer, + indices, + self._curand_state, + self._args.mean, + self._args.std_dev, + self._args.lower, + self._args.upper, + ) + + +class UniformInitializer(BaseDynamicEmbInitializer): + def __init__(self, args: DynamicEmbInitializerArgs): + super().__init__(args) + self._curand_state = CurandStateContext() + + def __call__( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + keys: Optional[torch.Tensor], # remove it when debug mode is removed + ) -> None: + uniform_init( + buffer, indices, self._curand_state, self._args.lower, self._args.upper + ) + + +class ConstantInitializer(BaseDynamicEmbInitializer): + def __init__(self, args: DynamicEmbInitializerArgs): + super().__init__(args) + + def __call__( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + keys: Optional[torch.Tensor], # remove it when debug mode is removed + ) -> None: + const_init(buffer, indices, self._args.value) + + +class DebugInitializer(BaseDynamicEmbInitializer): + def __init__(self, args: DynamicEmbInitializerArgs): + super().__init__(args) + + def __call__( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + keys: Optional[torch.Tensor], # remove it when debug mode is removed + ) -> None: + debug_init(buffer, indices, keys) + + +def create_initializer_from_args( + initializer_args: DynamicEmbInitializerArgs, +) -> BaseDynamicEmbInitializer: + """ + Factory function to create an initializer instance from initializer arguments. + """ + mode = initializer_args.mode + if mode == DynamicEmbInitializerMode.NORMAL: + return NormalInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.TRUNCATED_NORMAL: + return TruncatedNormalInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.UNIFORM: + return UniformInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.CONSTANT: + return ConstantInitializer(initializer_args) + elif mode == DynamicEmbInitializerMode.DEBUG: + return DebugInitializer(initializer_args) + else: + raise ValueError(f"Not supported initializer type: {mode}") diff --git a/corelib/dynamicemb/dynamicemb/input_dist.py b/corelib/dynamicemb/dynamicemb/input_dist.py index acf98e263..cf4eb659d 100644 --- a/corelib/dynamicemb/dynamicemb/input_dist.py +++ b/corelib/dynamicemb/dynamicemb/input_dist.py @@ -154,7 +154,9 @@ def bucketize_kjt_before_all2all( # duplicate keys will be resolved by AllToAll keys=_fx_wrap_gen_list_n_times(kjt.keys(), num_buckets), values=bucketized_indices, - weights=pos if bucketize_pos else bucketized_weights, + weights=_determine_output_weights( + kjt, pos, bucketize_pos, bucketized_weights + ), lengths=bucketized_lengths.view(-1), offsets=None, stride=_fx_wrap_stride(kjt), @@ -167,6 +169,31 @@ def bucketize_kjt_before_all2all( ) +def _determine_output_weights(kjt, pos, bucketize_pos, bucketized_weights): + """ + Determine which weights to return: pos or bucketized_weights. + + If the input weights appear to be frequency counters (float values that were + converted from uint64), preserve them instead of overriding with pos. + """ + + if not bucketize_pos: + return bucketized_weights + + if kjt.weights_or_none() is not None and bucketized_weights is not None: + # weights_as_int = bucketized_weights.long() + # weights_back_to_float = weights_as_int.float() + + # if torch.allclose( + # bucketized_weights, weights_back_to_float, atol=1e-6 + # ) and torch.all(bucketized_weights >= 0): + return bucketized_weights + + result = pos if bucketize_pos else bucketized_weights + + return result + + class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]): """ Bucketizes sparse features in RW fashion and then redistributes with an AlltoAll diff --git a/corelib/dynamicemb/dynamicemb/key_value_table.py b/corelib/dynamicemb/dynamicemb/key_value_table.py new file mode 100644 index 000000000..e7383b2f7 --- /dev/null +++ b/corelib/dynamicemb/dynamicemb/key_value_table.py @@ -0,0 +1,1066 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Callable, Optional, Tuple + +import torch +from dynamicemb.dynamicemb_config import ( + DynamicEmbTableOptions, + create_dynamicemb_table, + torch_to_dyn_emb, +) +from dynamicemb.initializer import BaseDynamicEmbInitializer +from dynamicemb.optimizer import BaseDynamicEmbeddingOptimizerV2 +from dynamicemb.types import ( + EMBEDDING_TYPE, + KEY_TYPE, + OPT_STATE_TYPE, + SCORE_TYPE, + AdmissionStrategy, + Cache, + Counter, + Storage, + torch_dtype_to_np_dtype, +) +from dynamicemb_extensions import ( + EvictStrategy, + clear, + count_matched, + dyn_emb_capacity, + dyn_emb_cols, + dyn_emb_rows, + erase, + export_batch, + find_pointers, + find_pointers_with_scores, + insert_and_evict, + insert_and_evict_with_scores, + insert_or_assign, + load_from_pointers, + select, + select_index, +) + + +class Storage(abc.ABC): + @abc.abstractmethod + def __init__( + self, + options: DynamicEmbTableOptions, + optimizer: BaseDynamicEmbeddingOptimizerV2, + ): + pass + + @abc.abstractmethod + def find( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + @abc.abstractmethod + def insert( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: Optional[torch.Tensor] = None, + ) -> None: + pass + + @abc.abstractmethod + def update( + self, keys: torch.Tensor, grads: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + @abc.abstractmethod + def enable_update(self) -> bool: + ... + + @abc.abstractmethod + def dump( + self, + start: int, + end: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + num_dumped: torch.Tensor + dumped_keys: torch.Tensor + dumped_values: torch.Tensor + dumped_scores: torch.Tensor + return num_dumped, dumped_keys, dumped_values, dumped_scores + + @abc.abstractmethod + def load( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: torch.Tensor, + ) -> None: + pass + + @abc.abstractmethod + def embedding_dtype( + self, + ) -> torch.dtype: + pass + + @abc.abstractmethod + def embedding_dim( + self, + ) -> int: + pass + + @abc.abstractmethod + def value_dim( + self, + ) -> int: + pass + + @abc.abstractmethod + def init_optimizer_state( + self, + ) -> float: + pass + + +class EventQueue: + """A simple queue of CUDA events for stream synchronization in async prefetch.""" + + def __init__(self): + self._events = [] + + def produce(self) -> torch.cuda.Event: + """Create and record a new CUDA event, add it to the queue, and return it.""" + event = torch.cuda.Event() + self._events.append(event) + return event + + def consume(self) -> "Optional[torch.cuda.Event]": + """Pop and return the oldest CUDA event from the queue, or None if empty.""" + if self._events: + return self._events.pop(0) + return None + + def clear(self) -> None: + """Clear all events from the queue.""" + self._events.clear() + + def __len__(self) -> int: + return len(self._events) + + +class Cache(abc.ABC): + @abc.abstractmethod + def find( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + @abc.abstractmethod + def insert_and_evict( + self, + keys: torch.Tensor, + values: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + num_evicted: torch.Tensor + evicted_keys: torch.Tensor + evicted_values: torch.Tensor + evicted_scores: torch.Tensor + return num_evicted, evicted_keys, evicted_values, evicted_scores + + @abc.abstractmethod + def update( + self, keys: torch.Tensor, grads: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + @abc.abstractmethod + def flush(self, storage: Storage) -> None: + pass + + @abc.abstractmethod + def reset( + self, + ) -> None: + pass + + @property + @abc.abstractmethod + def event_queue(self) -> EventQueue: + pass + + @abc.abstractmethod + def cache_metrics( + self, + ) -> torch.Tensor: + pass + + @abc.abstractmethod + def set_record_cache_metrics(self, record: bool) -> None: + pass + + +class KeyValueTable(Cache, Storage): + def __init__( + self, + options: DynamicEmbTableOptions, + optimizer: BaseDynamicEmbeddingOptimizerV2, + ): + self.options = options + self.table = create_dynamicemb_table(options) + self.capacity = options.max_capacity + self.optimizer = optimizer + self.score: int = None + self._score_update = False + self._emb_dim = self.options.dim + self._emb_dtype = self.options.embedding_dtype + self._de_emb_dtype = torch_to_dyn_emb(self._emb_dtype) + self._value_dim = self._emb_dim + optimizer.get_state_dim(self._emb_dim) + self._initial_optim_state = optimizer.get_initial_optim_states() + + device_idx = torch.cuda.current_device() + self.device = torch.device(f"cuda:{device_idx}") + props = torch.cuda.get_device_properties(device_idx) + self._threads_in_wave = ( + props.multi_processor_count * props.max_threads_per_multi_processor + ) + + self._event_queue = EventQueue() + self._cache_metrics = torch.zeros(10, dtype=torch.long, device="cpu") + self._record_cache_metrics = False + self._use_score = self.table.evict_strategy() != EvictStrategy.KLru + + def find( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + if unique_keys.dtype != self.key_type(): + unique_keys = unique_keys.to(self.key_type()) + + if unique_embs.dtype != self.value_type(): + raise RuntimeError( + "Embedding dtype not match {} != {}".format( + unique_embs.dtype, self.value_type() + ) + ) + + batch = unique_keys.size(0) + assert unique_embs.dim() == 2 + assert unique_embs.size(0) == batch + + load_dim = unique_embs.size(1) + + device = unique_keys.device + if founds is None: + founds = torch.empty(batch, dtype=torch.bool, device=device) + pointers = torch.empty(batch, dtype=torch.long, device=device) + + scores = self.create_scores(batch, device, input_scores) + + if self._score_update: + find_pointers_with_scores( + self.table, batch, unique_keys, pointers, founds, scores + ) + else: + find_pointers(self.table, batch, unique_keys, pointers, founds) + + self.value_dim() + + if load_dim != 0: + load_from_pointers(pointers, unique_embs) + + missing = torch.logical_not(founds) + num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device) + num_missing_1: torch.Tensor = torch.empty(1, dtype=torch.long, device=device) + missing_keys: torch.Tensor = torch.empty_like(unique_keys) + missing_indices: torch.Tensor = torch.empty( + batch, dtype=torch.long, device=device + ) + select(missing, unique_keys, missing_keys, num_missing_0) + select_index(missing, missing_indices, num_missing_1) + + if self._record_cache_metrics: + self._cache_metrics[0] = batch + self._cache_metrics[1] = founds.sum().item() + + h_num_missing = num_missing_0.cpu().item() + + # Handle missing scores: return None if scores is None + if scores is not None: + missing_scores = scores[missing_indices[:h_num_missing]] + else: + missing_scores = None + + return ( + h_num_missing, + missing_keys[:h_num_missing], + missing_indices[:h_num_missing], + missing_scores, + ) + + def find_embeddings( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + # Check shape to prevent misuse of find_embeddings and find + if unique_embs.dim() == 2 and unique_embs.size(1) != self.embedding_dim(): + raise ValueError( + f"find_embeddings expects dim={self.embedding_dim()}, got {unique_embs.size(1)}. " + ) + return self.find_impl(unique_keys, unique_embs, founds, input_scores) + + def find_missed_keys( + self, + unique_keys: torch.Tensor, + founds: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + # dummy tensor + unique_embs = torch.empty( + unique_keys.numel(), 0, device=unique_keys.device, dtype=self._emb_dtype + ) + return self.find_impl(unique_keys, unique_embs, founds, None) + + def find( + self, + unique_keys: torch.Tensor, + unique_vals: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + # Check shape to prevent misuse of find_embeddings and find + if unique_vals.dim() == 2 and unique_vals.size(1) != self.value_dim(): + raise ValueError( + f"find expects dim={self.value_dim()}, got {unique_vals.size(1)}. " + ) + return self.find_impl(unique_keys, unique_vals, founds, input_scores) + + def create_scores( + self, + h_num_total: int, + device: torch.device, + lfu_accumulated_frequency: Optional[torch.Tensor] = None, + ) -> Optional[torch.Tensor]: + """Create scores tensor for lookup operation based on eviction strategy.""" + if ( + lfu_accumulated_frequency is not None + and self.evict_strategy() == EvictStrategy.KLfu + ): + return lfu_accumulated_frequency + elif self.evict_strategy() == EvictStrategy.KLfu: + scores = torch.ones(h_num_total, device=device, dtype=torch.long) + return scores + elif self.evict_strategy() == EvictStrategy.KCustomized: + scores = torch.empty(h_num_total, device=device, dtype=torch.long) + scores.fill_(self.score) + return scores + else: + return None + + def insert( + self, + unique_keys: torch.Tensor, + unique_values: torch.Tensor, + scores: Optional[torch.Tensor] = None, + ) -> None: + h_num_unique_keys = unique_keys.size(0) + if self._use_score: + if scores is None: + scores = torch.empty( + h_num_unique_keys, device=unique_keys.device, dtype=torch.uint64 + ) + scores.fill_(self.score) + else: + scores = None + + if self.evict_strategy() == EvictStrategy.KLfu: + erase(self.table, h_num_unique_keys, unique_keys) + + insert_or_assign( + self.table, h_num_unique_keys, unique_keys, unique_values, scores + ) + + def update( + self, keys: torch.Tensor, grads: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert self._score_update == False, "update is called only in backward." + + batch = keys.size(0) + + device = keys.device + founds = torch.empty(batch, dtype=torch.bool, device=device) + pointers = torch.empty(batch, dtype=torch.long, device=device) + find_pointers(self.table, batch, keys, pointers, founds) + + self.optimizer.fused_update_with_pointer(grads, pointers, self._de_emb_dtype) + + missing = torch.logical_not(founds) + num_missing_0: torch.Tensor = torch.empty(1, dtype=torch.long, device=device) + num_missing_1: torch.Tensor = torch.empty(1, dtype=torch.long, device=device) + missing_keys: torch.Tensor = torch.empty_like(keys) + missing_indices: torch.Tensor = torch.empty( + batch, dtype=torch.long, device=device + ) + select(missing, keys, missing_keys, num_missing_0) + select_index(missing, missing_indices, num_missing_1) + return num_missing_0, missing_keys, missing_indices + + def enable_update(self) -> bool: + return True + + def set_score( + self, + score: int, + ) -> None: + self.score = score + + @property + def score_update( + self, + ) -> None: + return self._score_update + + @score_update.setter + def score_update(self, value: bool): + self._score_update = value + + def dump( + self, + start: int, + end: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + batch = end - start + device = self.device + key_dtype = self.options.index_type + value_dtype = self._emb_dtype + dim: int = self._value_dim + + num_dumped: torch.Tensor = torch.zeros(1, dtype=torch.uint64, device=device) + dumped_keys: torch.Tensor = torch.empty(batch, dtype=key_dtype, device=device) + dumped_values: torch.Tensor = torch.empty( + batch, dim, dtype=value_dtype, device=device + ) + dumped_scores: torch.Tensor = torch.empty( + batch, dtype=torch.uint64, device=device + ) + + export_batch( + self.table, + batch, + start, + num_dumped, + dumped_keys, + dumped_values, + dumped_scores, + ) + + return num_dumped, dumped_keys, dumped_values, dumped_scores + + def load( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: torch.Tensor, + ) -> None: + self.insert(keys, values, scores) + + def embedding_dtype( + self, + ) -> torch.dtype: + return self._emb_dtype + + def value_dim( + self, + ) -> int: + return self._value_dim + + def embedding_dim( + self, + ) -> int: + return self._emb_dim + + def init_optimizer_state( + self, + ) -> float: + return self._initial_optim_state + + def insert_and_evict( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + batch = keys.numel() + num_evicted: torch.Tensor = torch.zeros(1, dtype=torch.long, device=keys.device) + evicted_keys: torch.Tensor = torch.empty_like(keys) + evicted_values: torch.Tensor = torch.empty_like(values) + evicted_scores: torch.Tensor = torch.empty( + batch, dtype=torch.uint64, device=keys.device + ) + if scores is not None: + insert_and_evict_with_scores( + self.table, + batch, + keys, + values, + evicted_keys, + evicted_values, + evicted_scores, + num_evicted, + scores=scores, # scores as keyword argument + ) + else: + # TODO: Fix prefetch issue when scores is not provided + insert_and_evict( + self.table, + batch, + keys, + values, + self.score if self._use_score else None, + evicted_keys, + evicted_values, + evicted_scores, + num_evicted, + ) + if self._record_cache_metrics: + self._cache_metrics[2] = batch + self._cache_metrics[3] = num_evicted.cpu().item() + return num_evicted, evicted_keys, evicted_values, evicted_scores + + def flush(self, storage: Storage) -> None: + batch_size = self._threads_in_wave + for start in range(0, self.capacity, batch_size): + end = min(start + batch_size, self.capacity) + num_dumped, dumped_keys, dumped_values, dumped_scores = self.dump( + start, end + ) + h_num_dumped = num_dumped.cpu().item() + dumped_keys = dumped_keys[:h_num_dumped] + dumped_values = dumped_values[:h_num_dumped, :] + dumped_scores = dumped_scores[:h_num_dumped] + storage.insert(dumped_keys, dumped_values, dumped_scores) + + def reset( + self, + ) -> None: + clear(self.table) + self._event_queue.clear() + + @property + def event_queue(self) -> EventQueue: + return self._event_queue + + @property + def cache_metrics(self) -> Optional[torch.Tensor]: + return self._cache_metrics if self._record_cache_metrics else None + + def set_record_cache_metrics(self, record: bool) -> None: + self._record_cache_metrics = record + return + + +def update_cache( + cache: Cache, + storage: Storage, + missing_keys: torch.Tensor, + missing_values: torch.Tensor, + missing_scores: Optional[torch.Tensor] = None, +): + # need to update score. + num_evicted, evicted_keys, evicted_values, evicted_scores = cache.insert_and_evict( + missing_keys, + missing_values, + missing_scores, + ) + + if num_evicted != 0: + storage.insert( + evicted_keys[:h_num_evicted], + evicted_values[:h_num_evicted, :], + evicted_scores[:h_num_evicted], + ) + + +def admission( + keys: torch.Tensor, + freqs: torch.Tensor, + admit_strategy: AdmissionStrategy, + admission_counter: Counter, +) -> torch.Tensor: + freq_for_missing_keys = admission_counter.add(keys, freqs, inplace=True) + admit_mask = admit_strategy.admit( + keys, + freq_for_missing_keys, + ) + admitted_keys = keys[admit_mask] + admission_counter.erase(admitted_keys) + + return admit_mask + + +class KeyValueTableFunction: + @staticmethod + def lookup( + storage: Storage, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + initializer: Callable, + training: bool, + evict_strategy: EvictStrategy, + accumulated_frequency: Optional[torch.Tensor] = None, + admit_strategy: Optional[AdmissionStrategy] = None, + admission_counter: Optional[Counter] = None, + ) -> None: + assert unique_keys.dim() == 1 + h_num_toatl = unique_keys.numel() + emb_dim = storage.embedding_dim() + emb_dtype = storage.embedding_dtype() + val_dim = storage.value_dim() + + is_lfu_enabled = evict_strategy == EvictStrategy.KLfu + + if h_num_toatl == 0: + return + + # 1. find in storage + founds = torch.empty(h_num_toatl, device=unique_keys.device, dtype=torch.bool) + ( + h_num_missing_in_storage, + missing_keys_in_storage, + missing_indices_in_storage, + missing_scores_in_storage, + ) = storage.find_embeddings( + unique_keys, + unique_embs, + founds=founds, + input_scores=accumulated_frequency if is_lfu_enabled else None, + ) + + if h_num_missing_in_storage == 0: + return + + # if training and admit_strategy is not None: + + admit_mask = None + indices_to_init = missing_indices_in_storage + if training and admit_strategy is not None: + # do admission first + if accumulated_frequency is not None: + counters_for_admission = accumulated_frequency[ + missing_indices_in_storage + ] + else: + counters_for_admission = torch.ones( + missing_keys_in_storage.shape[0], + dtype=torch.int64, + device=unique_keys.device, + ) + + admit_mask = admission( + missing_keys_in_storage, + counters_for_admission, + admit_strategy, + admission_counter, + ) + + non_admitted_mask = ~admit_mask + non_admitted_indices = missing_indices_in_storage[non_admitted_mask] + initiailized_non_admitted_indices = False + if non_admitted_indices.numel() > 0: + initiailized_non_admitted_indices = ( + admit_strategy.initialize_non_admitted_embeddings( + unique_embs[:, :emb_dim], + non_admitted_indices, + ) + ) + + # Only initialize admitted embeddings with the regular initializer + if not initiailized_non_admitted_indices: + indices_to_init = missing_indices_in_storage[admit_mask] + + # 2. initialize missing embeddings (admitted or all if no admission) + if indices_to_init.numel() > 0: + initializer( + unique_embs, + indices_to_init, + unique_keys, + ) + + if training: + # insert missing values + missing_values_in_storage = torch.empty( + h_num_missing_in_storage, + val_dim, + device=unique_keys.device, + dtype=emb_dtype, + ) + missing_values_in_storage[:, :emb_dim] = unique_embs[ + missing_indices_in_storage, : + ] + if val_dim != emb_dim: + missing_values_in_storage[ + :, emb_dim - val_dim : + ] = storage.init_optimizer_state() + keys_to_insert = missing_keys_in_storage + values_to_insert = missing_values_in_storage + scores_to_insert = missing_scores_in_storage + if training and admit_strategy is not None: + keys_to_insert = keys_to_insert[admit_mask] + values_to_insert = values_to_insert[admit_mask] + scores_to_insert = ( + scores_to_insert[admit_mask] + if scores_to_insert is not None + else None + ) + + # 3. insert missing values into table. + storage.insert( + keys_to_insert, + values_to_insert, + scores_to_insert, + ) + # ignore the storage missed in eval mode + + @staticmethod + def update( + storage: Storage, + unique_keys: torch.Tensor, + unique_grads: torch.Tensor, + optimizer: BaseDynamicEmbeddingOptimizerV2, + ): + if storage.enable_update(): + storage.update(unique_keys, unique_grads, return_missing=False) + return + + emb_dtype = storage.embedding_dtype() + val_dim = storage.value_dim() + h_num_toatl = unique_keys.numel() + unique_values = torch.empty( + h_num_toatl, val_dim, device=unique_keys.device, dtype=emb_dtype + ) + founds = torch.empty(h_num_toatl, device=unique_keys.device, dtype=torch.bool) + _, _, _, _ = storage.find(unique_keys, unique_values, founds=founds) + + keys_for_storage = unique_keys[founds].contiguous() + values_for_storage = unique_values[founds, :].contiguous() + grads_for_storage = unique_grads[founds, :].contiguous() + optimizer.fused_update( + grads_for_storage, + values_for_storage, + ) + + storage.insert(keys_for_storage, values_for_storage) + + return + + +class KeyValueTableCachingFunction: + @staticmethod + def lookup( + cache: Cache, # partial emb + optimizer state + storage: Storage, # full emb + optimizer state + unique_keys: torch.Tensor, # input + unique_embs: torch.Tensor, # output + initializer: Callable, + enable_prefetch: bool, + training: bool, + evict_strategy: EvictStrategy, + accumulated_frequency: Optional[torch.Tensor] = None, + admit_strategy: Optional[AdmissionStrategy] = None, + admission_counter: Optional[Counter] = None, + ) -> None: + assert unique_keys.dim() == 1 + h_num_toatl = unique_keys.numel() + emb_dim = storage.embedding_dim() + emb_dtype = storage.embedding_dtype() + val_dim = storage.value_dim() + caching = cache is not None + + is_lfu_enabled = evict_strategy == EvictStrategy.KLfu + + ( + h_num_keys_for_storage, + missing_keys, + missing_indices, + missing_scores, + ) = cache.find_embeddings( + unique_keys, + unique_embs, + input_scores=accumulated_frequency if is_lfu_enabled else None, + ) + if h_num_keys_for_storage == 0: + return + keys_for_storage = missing_keys + + scores_for_storage = missing_scores + + founds = torch.empty( + h_num_keys_for_storage, device=unique_keys.device, dtype=torch.bool + ) + + # 2. find in storage + if caching and not enable_prefetch: + storage_load_dim = val_dim + else: + storage_load_dim = emb_dim + values_for_storage = torch.empty( + h_num_keys_for_storage, + storage_load_dim, + device=unique_keys.device, + dtype=emb_dtype, + ) + founds = torch.empty( + h_num_keys_for_storage, device=unique_keys.device, dtype=torch.bool + ) + ( + num_missing_in_storage, + missing_keys_in_storage, + missing_indices_in_storage, + missing_scores_in_storage, + ) = storage.find( + keys_for_storage, + values_for_storage, + founds=founds, + input_scores=scores_for_storage, + ) + + admit_mask_for_missing_keys = None + indices_to_init = missing_indices_in_storage + if training and admit_strategy is not None: + # Get frequency counters for admission: + if accumulated_frequency is not None: + # missing_indices_in_storage is index in keys_for_storage, Need to convert to index in unique_keys via missing_indices + indices_in_unique_keys = missing_indices[missing_indices_in_storage] + counters_for_admission = accumulated_frequency[indices_in_unique_keys] + else: + counters_for_admission = torch.ones( + missing_keys_in_storage.shape[0], + dtype=torch.int64, + device=unique_keys.device, + ) + + admit_mask_for_missing_keys = admission( + missing_keys_in_storage, + counters_for_admission, + admit_strategy, + admission_counter, + ) + + non_admitted_mask = ~admit_mask_for_missing_keys + non_admitted_indices = missing_indices_in_storage[non_admitted_mask] + initiailized_non_admitted_indices = False + if non_admitted_indices.numel() > 0: + initiailized_non_admitted_indices = ( + admit_strategy.initialize_non_admitted_embeddings( + values_for_storage[:, :emb_dim], + non_admitted_indices, + ) + ) + + # Only initialize admitted embeddings with the regular initializer + if not initiailized_non_admitted_indices: + indices_to_init = missing_indices_in_storage[ + admit_mask_for_missing_keys + ] + + # 3. initialize missing embeddings (admitted or all if no admission) + if indices_to_init.numel() > 0: + initializer( + values_for_storage[:, :emb_dim], + indices_to_init, + keys_for_storage, + ) + + # 4. copy embeddings only + unique_embs[missing_indices, :] = values_for_storage[:, :emb_dim] + + if h_num_missing_in_storage == 0: + return + + keys_to_update = None + values_to_update = None + scores_to_update = None + + if training: + if emb_dim != val_dim: + values_for_storage[ + missing_indices_in_storage, emb_dim - val_dim : + ] = storage.init_optimizer_state() + # 5.Optional Admission part + keys_to_update = keys_for_storage + values_to_update = values_for_storage + scores_to_update = scores_for_storage + + if admit_strategy is not None: + # build mask: including storage hit keys + keys that are both miss and admitted + mask_to_cache = founds + admitted_indices = missing_indices_in_storage[ + admit_mask_for_missing_keys + ] + mask_to_cache[admitted_indices] = True + + keys_to_update = keys_for_storage[mask_to_cache] + values_to_update = values_for_storage[mask_to_cache] + scores_to_update = ( + scores_for_storage[mask_to_cache] + if scores_for_storage is not None + else None + ) + else: # only update those found in the storage to cache. + found_keys_in_storage = keys_for_storage[founds].contiguous() + found_values_in_storage = values_for_storage[founds, :].contiguous() + found_scores_in_storage = ( + scores_for_storage[founds].contiguous() + if scores_for_storage is not None + else None + ) + keys_to_update = found_keys_in_storage + values_to_update = found_values_in_storage + scores_to_update = found_scores_in_storage + + update_cache(cache, storage, keys_to_update, values_to_update, scores_to_update) + return + + @staticmethod + def update( + cache: Optional[Cache], + storage: Storage, + unique_keys: torch.Tensor, + unique_grads: torch.Tensor, + optimizer: BaseDynamicEmbeddingOptimizerV2, + enable_prefetch: bool, + ): + if cache is not None: + num_missing, missing_keys, missing_indices = cache.update( + unique_keys, unique_grads + ) + h_num_keys_for_storage = num_missing.cpu().item() + keys_for_storage = missing_keys[:h_num_keys_for_storage] + missing_indices = missing_indices[:h_num_keys_for_storage] + grads_for_storage = unique_grads[missing_indices, :].contiguous() + else: + keys_for_storage = unique_keys + grads_for_storage = unique_grads + + if storage.enable_update(): + storage.update(keys_for_storage, grads_for_storage) + return + + emb_dtype = storage.embedding_dtype() + val_dim = storage.value_dim() + storage.embedding_dim() + values_for_storage = torch.empty( + h_num_keys_for_storage, val_dim, device=unique_keys.device, dtype=emb_dtype + ) + founds = torch.empty( + h_num_keys_for_storage, device=unique_keys.device, dtype=torch.bool + ) + _, _, _, _ = storage.find(keys_for_storage, values_for_storage, founds=founds) + keys_for_storage = keys_for_storage[founds].contiguous() + values_for_storage = values_for_storage[founds, :].contiguous() + grads_for_storage = grads_for_storage[founds, :].contiguous() + optimizer.fused_update( + grads_for_storage, + values_for_storage, + ) + + storage.insert(keys_for_storage, values_for_storage) + return + + @staticmethod + def prefetch( + cache: Cache, + storage: Storage, + unique_keys: torch.Tensor, + initializer: BaseDynamicEmbInitializer, + training: bool = True, + forward_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + assert cache is not None + emb_dtype = storage.embedding_dtype() + h_num_keys_for_storage, missing_keys, _, _ = cache.find_missed_keys(unique_keys) + + h_num_keys_for_storage = num_missing.cpu().item() + missing_keys = missing_keys[:h_num_keys_for_storage] + if h_num_keys_for_storage == 0: + if forward_stream is not None: + cache.event_queue.produce().record() + return + + val_dim = storage.value_dim() + emb_dim = storage.embedding_dim() + values_for_storage = torch.empty( + h_num_keys_for_storage, val_dim, device=unique_keys.device, dtype=emb_dtype + ) + founds = torch.empty( + h_num_keys_for_storage, device=unique_keys.device, dtype=torch.bool + ) + ( + num_missing_in_storage, + missing_keys_in_storage, + missing_indices_in_storage, + _, + ) = storage.find(keys_for_storage, values_for_storage, founds=founds) + + h_num_missing_in_storage = num_missing_in_storage.cpu().item() + missing_indices_in_storage = missing_indices_in_storage[ + :h_num_missing_in_storage + ] + missing_keys_in_storage = missing_keys_in_storage[:h_num_missing_in_storage] + if h_num_missing_in_storage != 0: + if training: + embs_for_storage = values_for_storage[:, :emb_dim] + initializer( + embs_for_storage, + missing_indices_in_storage, + missing_keys_in_storage, + ) + values_for_storage[ + missing_indices_in_storage, emb_dim - val_dim : + ] = storage.init_optimizer_state() + else: + missing_keys = missing_keys[founds] + values_for_storage = values_for_storage[founds, :] + + update_cache( + cache, + storage, + keys_for_storage, + values_for_storage, + None, # prefetch does not update scores + ) diff --git a/corelib/dynamicemb/dynamicemb/optimizer.py b/corelib/dynamicemb/dynamicemb/optimizer.py index 1e272b5ad..8f594bf24 100644 --- a/corelib/dynamicemb/dynamicemb/optimizer.py +++ b/corelib/dynamicemb/dynamicemb/optimizer.py @@ -17,15 +17,22 @@ import copy import enum from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Union import torch # usort:skip from dynamicemb.dynamicemb_config import * -from dynamicemb_extensions import ( - DynamicEmbTable, +from dynamicemb_extensions import ( # dynamic_emb_sgd,; dynamic_emb_adam,; dynamic_emb_adagrad,; dynamic_emb_rowwise_adagrad, + dynamic_emb_adagrad_fused, + dynamic_emb_adagrad_with_pointer, dynamic_emb_adagrad_with_table, + dynamic_emb_adam_fused, + dynamic_emb_adam_with_pointer, dynamic_emb_adam_with_table, + dynamic_emb_rowwise_adagrad_fused, + dynamic_emb_rowwise_adagrad_with_pointer, dynamic_emb_rowwise_adagrad_with_table, + dynamic_emb_sgd_fused, + dynamic_emb_sgd_with_pointer, dynamic_emb_sgd_with_table, ) @@ -155,7 +162,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: ... @@ -167,6 +173,14 @@ def get_opt_args(self) -> Dict[str, Any]: def set_opt_args(self, args: Dict[str, Any]) -> None: ... + def need_gradient_clipping(self) -> bool: + return self._opt_args.gradient_clipping + + def clip_gradient(self, grads) -> None: + grads.clamp_( + min=-1 * self._opt_args.max_gradient, max=self._opt_args.max_gradient + ) + class SGDDynamicEmbeddingOptimizer(BaseDynamicEmbeddingOptimizer): def __init__( @@ -182,7 +196,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._hashtables: @@ -199,9 +212,14 @@ def update( indice = indices[i] num_indice = indice.shape[0] weight_dtype = torch_to_dyn_emb(table_option.embedding_dtype) - score = scores[i] if scores is not None else None + dynamic_emb_sgd_with_table( - ht, num_indice, indice, grad, lr, weight_dtype, score + ht, + num_indice, + indice, + grad, + lr, + weight_dtype, ) def get_opt_args(self): @@ -231,7 +249,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._table_state_map.keys(): @@ -254,7 +271,7 @@ def update( num_indice = indice.shape[0] weight_dtype = torch_to_dyn_emb(table_option.embedding_dtype) - score = scores[i] if scores is not None else None + dynamic_emb_adam_with_table( ht, num_indice, @@ -267,7 +284,6 @@ def update( weight_decay, self._iterations, weight_dtype, - score, ) def get_opt_args(self): @@ -310,7 +326,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._table_state_map.keys(): @@ -329,10 +344,9 @@ def update( num_indice = indice.shape[0] weight_dtype = torch_to_dyn_emb(table_option.embedding_dtype) - score = scores[i] if scores is not None else None dynamic_emb_adagrad_with_table( - ht, num_indice, indice, grad, lr, eps, weight_dtype, score + ht, num_indice, indice, grad, lr, eps, weight_dtype ) def get_opt_args(self): @@ -372,7 +386,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._table_state_map.keys(): @@ -390,10 +403,9 @@ def update( num_indice = indice.shape[0] weight_dtype = torch_to_dyn_emb(table_option.embedding_dtype) - score = scores[i] if scores is not None else None dynamic_emb_rowwise_adagrad_with_table( - ht, num_indice, indice, grad, lr, eps, weight_dtype, score + ht, num_indice, indice, grad, lr, eps, weight_dtype ) def get_opt_args(self): @@ -412,3 +424,424 @@ def set_opt_args(self, args: Dict[str, Any]): for table in self._state_dict["Gt"]: table.set_initial_optstate(initial_value) return + + +class BaseDynamicEmbeddingOptimizerV2(abc.ABC): + def __init__( + self, + opt_args: OptimizerArgs, + ) -> None: + self._opt_args: OptimizerArgs = copy.deepcopy(opt_args) + + @abc.abstractmethod + def update( + self, + grads: torch.Tensor, + embs: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + ... + + @abc.abstractmethod + def fused_update( + self, + grads: torch.Tensor, + values: torch.Tensor, + ) -> None: + ... + + @abc.abstractmethod + def fused_update_with_pointer( + self, + grads: torch.Tensor, + value_ptr: torch.Tensor, # pointers to embeddng + optimizer states + ) -> None: + ... + + @abc.abstractmethod + def get_opt_args(self) -> Dict[str, Any]: + ... + + @abc.abstractmethod + def set_opt_args(self, args: Dict[str, Any]) -> None: + ... + + @abc.abstractmethod + def get_state_dim(self, emb_dim: int) -> int: + """ + Get the state dim. + """ + + def set_learning_rate(self, new_lr) -> None: + self._opt_args.learning_rate = new_lr + return + + def get_initial_optim_states(self) -> float: + return self._opt_args.initial_accumulator_value + + def set_initial_optim_states(self, value: float) -> None: + self._opt_args.initial_accumulator_value = value + return + + def step(self) -> None: + pass + + +class SGDDynamicEmbeddingOptimizerV2(BaseDynamicEmbeddingOptimizerV2): + def __init__( + self, + opt_args: OptimizerArgs, + ) -> None: + super().__init__(opt_args) + + def update( + self, + grads: torch.Tensor, + embs: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + pass + # lr = self._opt_args.learning_rate + # dynamic_emb_sgd( + # grads.size(0), + # grads, + # embs, + # lr, + # ) + + def fused_update( + self, + grads: torch.Tensor, + values: torch.Tensor, + ) -> None: + lr = self._opt_args.learning_rate + dynamic_emb_sgd_fused( + grads, + values, + lr, + ) + + def fused_update_with_pointer( + self, + grads: torch.Tensor, + value_ptr: torch.Tensor, # pointers to embeddng + optimizer states + value_type, + ) -> None: + lr = self._opt_args.learning_rate + dynamic_emb_sgd_with_pointer( + grads, + value_ptr, + value_type, + lr, + ) + + def get_opt_args(self): + ret_args = {"lr": self._opt_args.learning_rate} + return ret_args + + def set_opt_args(self, args: Dict[str, Any]): + self._opt_args.learning_rate = get_required_arg(args, "lr") + return + + def get_state_dim(self, emb_dim: int) -> int: + """ + Get the state dim. + """ + return 0 + + +class AdamDynamicEmbeddingOptimizerV2(BaseDynamicEmbeddingOptimizerV2): + def __init__( + self, + opt_args: OptimizerArgs, + ) -> None: + super().__init__(opt_args) + self._iterations: int = 0 + + def step(self): + self._iterations += 1 + + def update( + self, + grads: torch.Tensor, + embs: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + pass + # assert states is not None + + # lr = self._opt_args.learning_rate + # beta1 = self._opt_args.beta1 + # beta2 = self._opt_args.beta2 + # weight_decay = self._opt_args.weight_decay + # eps = self._opt_args.eps + + # dynamic_emb_adam( + # grads.size(0), + # grads, + # embs, + # states, + # lr, + # beta1, + # beta2, + # eps, + # weight_decay, + # self._iterations, + # ) + + def fused_update( + self, + grads: torch.Tensor, + values: torch.Tensor, + ) -> None: + lr = self._opt_args.learning_rate + beta1 = self._opt_args.beta1 + beta2 = self._opt_args.beta2 + weight_decay = self._opt_args.weight_decay + eps = self._opt_args.eps + + dynamic_emb_adam_fused( + grads, + values, + lr, + beta1, + beta2, + eps, + weight_decay, + self._iterations, + ) + + def fused_update_with_pointer( + self, + grads: torch.Tensor, + value_ptr: torch.Tensor, # pointers to embeddng + optimizer states + value_type, + ) -> None: + lr = self._opt_args.learning_rate + beta1 = self._opt_args.beta1 + beta2 = self._opt_args.beta2 + weight_decay = self._opt_args.weight_decay + eps = self._opt_args.eps + + emb_dim = grads.size(1) + state_dim = self.get_state_dim(emb_dim) + + dynamic_emb_adam_with_pointer( + grads, + value_ptr, + value_type, + state_dim, + lr, + beta1, + beta2, + eps, + weight_decay, + self._iterations, + ) + + def get_opt_args(self): + ret_args = { + "lr": self._opt_args.learning_rate, + "iters": self._iterations, + "beta1": self._opt_args.beta1, + "beta2": self._opt_args.beta2, + "eps": self._opt_args.eps, + "weight_decay": self._opt_args.weight_decay, + } + return ret_args + + def set_opt_args(self, args: Dict[str, Any]): + self._opt_args.learning_rate = get_required_arg(args, "lr") + self._iterations = get_required_arg(args, "iters") + self._opt_args.beta1 = get_required_arg(args, "beta1") + self._opt_args.beta2 = get_required_arg(args, "beta2") + self._opt_args.eps = get_required_arg(args, "eps") + self._opt_args.weight_decay = get_required_arg(args, "weight_decay") + return + + def get_state_dim(self, emb_dim: int) -> int: + """ + Get the state dim. + """ + return emb_dim * 2 + + +class AdaGradDynamicEmbeddingOptimizerV2(BaseDynamicEmbeddingOptimizerV2): + def __init__( + self, + opt_args: OptimizerArgs, + ) -> None: + super().__init__(opt_args) + + def update( + self, + grads: torch.Tensor, + embs: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + pass + # lr = self._opt_args.learning_rate + # eps = self._opt_args.eps + + # dynamic_emb_adagrad( + # grads.size(0), + # grads, + # embs, + # states, + # lr, + # eps, + # ) + + def fused_update( + self, + grads: torch.Tensor, + values: torch.Tensor, + ) -> None: + lr = self._opt_args.learning_rate + eps = self._opt_args.eps + + dynamic_emb_adagrad_fused( + grads, + values, + lr, + eps, + ) + + def fused_update_with_pointer( + self, + grads: torch.Tensor, + value_ptr: torch.Tensor, # pointers to embeddng + optimizer states + value_type, + ) -> None: + lr = self._opt_args.learning_rate + eps = self._opt_args.eps + + emb_dim = grads.size(1) + state_dim = self.get_state_dim(emb_dim) + + dynamic_emb_adagrad_with_pointer( + grads, + value_ptr, + value_type, + state_dim, + lr, + eps, + ) + + def get_opt_args(self): + ret_args = { + "lr": self._opt_args.learning_rate, + "eps": self._opt_args.eps, + "initial_accumulator_value": self._opt_args.initial_accumulator_value, + } + return ret_args + + def set_opt_args(self, args: Dict[str, Any]): + self._opt_args.learning_rate = get_required_arg(args, "lr") + self._opt_args.eps = get_required_arg(args, "eps") + initial_value = get_required_arg(args, "initial_accumulator_value") + self._opt_args.initial_accumulator_value = initial_value + return + + def get_state_dim(self, emb_dim: int) -> int: + """ + Get the state dim. + """ + return emb_dim + + +class RowWiseAdaGradDynamicEmbeddingOptimizerV2(BaseDynamicEmbeddingOptimizerV2): + def __init__( + self, + opt_args: OptimizerArgs, + emb_dtype: torch.dtype, + ) -> None: + super().__init__(opt_args) + + DTYPE_NUM_BYTES: Dict[torch.dtype, int] = { + torch.float32: 4, + torch.float16: 2, + torch.bfloat16: 2, + } + self._optim_state_dim = 16 // DTYPE_NUM_BYTES[emb_dtype] + + def update( + self, + grads: torch.Tensor, + embs: torch.Tensor, + states: Optional[torch.Tensor], + ) -> None: + pass + # lr = self._opt_args.learning_rate + # eps = self._opt_args.eps + + # dynamic_emb_rowwise_adagrad( + # grads.size(0), + # grads, + # embs, + # states, + # lr, + # eps, + # ) + + def fused_update( + self, + grads: torch.Tensor, + values: torch.Tensor, + ) -> None: + lr = self._opt_args.learning_rate + eps = self._opt_args.eps + + emb_dim = grads.size(1) + self.get_state_dim(emb_dim) + + dynamic_emb_rowwise_adagrad_fused( + grads.size(0), + grads, + values, + lr, + eps, + ) + + def fused_update_with_pointer( + self, + grads: torch.Tensor, + value_ptr: torch.Tensor, # pointers to embeddng + optimizer states + value_type, + ) -> None: + lr = self._opt_args.learning_rate + eps = self._opt_args.eps + + emb_dim = grads.size(1) + state_dim = self.get_state_dim(emb_dim) + + dynamic_emb_rowwise_adagrad_with_pointer( + grads.size(0), + grads, + value_ptr, + value_type, + state_dim, + lr, + eps, + ) + + def get_opt_args(self): + ret_args = { + "lr": self._opt_args.learning_rate, + "eps": self._opt_args.eps, + "initial_accumulator_value": self._opt_args.initial_accumulator_value, + } + return ret_args + + def set_opt_args(self, args: Dict[str, Any]): + self._opt_args.learning_rate = get_required_arg(args, "lr") + self._opt_args.eps = get_required_arg(args, "eps") + initial_value = get_required_arg(args, "initial_accumulator_value") + self._opt_args.initial_accumulator_value = initial_value + return + + def get_state_dim(self, emb_dim: int) -> int: + """ + Get the state dim. + """ + return self._optim_state_dim diff --git a/corelib/dynamicemb/dynamicemb/scored_hashtable.py b/corelib/dynamicemb/dynamicemb/scored_hashtable.py new file mode 100644 index 000000000..2418f4f1b --- /dev/null +++ b/corelib/dynamicemb/dynamicemb/scored_hashtable.py @@ -0,0 +1,1010 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import enum +import os +import warnings +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import numpy as np +import torch +import torch.distributed as dist +from dynamicemb.dynamicemb_config import dtype_to_bytes +from dynamicemb.types import ( + COUNTER_TYPE, + KEY_TYPE, + SCORE_TYPE, + MemoryType, + torch_dtype_to_np_dtype, +) +from dynamicemb_extensions import ( + ScorePolicy, + device_timestamp, + table_erase, + table_export_batch, + table_insert, + table_insert_and_evict, + table_lookup, + table_partition, +) + + +@dataclass(frozen=True) +class ScoreSpec: + name: str + policy: ScorePolicy # How to set the new score, this is the default behavior. + dtype: torch.dtype = torch.uint64 + priority: int = 0 # If multiple scores exist, the one with lower priority will be reduced first. + is_reduction: bool = True # Whether it is reduced + + +@dataclass +class ScoreArg: + name: str + value: Optional[torch.Tensor] = None + is_return: bool = ( + False # Whether return the new score, if true will overwrite the `value` + ) + policy: Optional[ + ScorePolicy + ] = None # How to set the new score, and providing this will override the default. + + +@enum.unique +class ProbingType(enum.Enum): + LINEAR = "linear" + CHAINED = "separate_chain" + + +@enum.unique +class ReductionType(enum.Enum): + LINEAR = "linear" + DOUBLY_LINKED = "doubly_linked" + + +class ScoredHashTable(abc.ABC): + """ + Multiple scores are supported. + If a hash collision cannot be resolved during insertion, the key with the lower score will be evicted. + The value of the table is the index/ID of each key in the table, which is read-only. + """ + + @property + @abc.abstractmethod + def key_type(self) -> torch.dtype: + """ + Return the key type. + """ + + @property + def index_type(self) -> torch.dtype: + """ + Return the index type. + """ + return torch.int64 + + @property + @abc.abstractmethod + def score_specs( + self, + score_names: List[str] = None, + ) -> List[ScoreSpec]: + """ + Return the score specifics. + """ + + @property + def result_type(self) -> torch.dtype: + """ + Return the insert-result type. + """ + return torch.uint8 + + @abc.abstractmethod + def lookup( + self, + keys: torch.Tensor, + scores: List[ScoreArg], + founds: Optional[torch.Tensor], + indices: torch.Tensor = None, + ) -> None: + """ + TODO: kernel fusion + Argument:: + missing_keys: torch.Tensor=None + missing_indices: torch.Tensor=None + missing_scores: List[ScoreArg]=None + Returns: + num_missing: int + """ + + @abc.abstractmethod + def insert( + self, + keys: torch.Tensor, + scores: List[ScoreArg], + indices: Optional[torch.Tensor] = None, + insert_results: Optional[torch.Tensor] = None, + ) -> None: + """ + Keys have to be unique. + Indices is output buffer if provided. + """ + + @abc.abstractmethod + def insert_and_evict( + self, + keys: torch.Tensor, + scores: List[ScoreArg], + indices: Optional[torch.Tensor] = None, + insert_results: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Keys have to be unique. + Indices is output buffer if provided. + """ + + num_evicted: int + evicted_keys: torch.Tensor + evicted_indices: torch.Tensor + evicted_scores: List[torch.Tensor] + return num_evicted, evicted_keys, evicted_indices, evicted_scores + + @abc.abstractmethod + def erase( + self, + keys: torch.Tensor, + ) -> None: + """ + Erase Keys + """ + + @abc.abstractmethod + def load( + self, + key_file: str, + score_files: Dict[str, str], + ) -> None: + """ + Load keys and scores from input file path. + + Args: + key_file (str): the file path of keys. + score_files: Dict[str, str]: Dict from score name to score file path. + """ + + @abc.abstractmethod + def dump( + self, + key_file: str, + score_files: Dict[str, str], + ) -> None: + """ + Dump keys and scores to output file path. + + Args: + key_file (str): the file path of keys. + score_files: Dict[str, str]: Dict from score name to score file path. + """ + + @abc.abstractmethod + def capacity(self) -> int: + """ + Return the capacity of the table. + """ + + @abc.abstractmethod + def size(self) -> int: + """ + Return the size of the table. + """ + + @abc.abstractmethod + def load_factor(self) -> float: + """ + Return the load factor of the table. + """ + + @abc.abstractmethod + def reserve( + self, + target_capacity, + ): + """ + Table's growth is controlled outside. + """ + + @abc.abstractmethod + def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: + """ + Get the consumption of a specific memory type. + + Args: + mem_type (MemoryType): the specific memory type, default to MemoryType.DEVICE. + """ + + +class GroupedScoredHashTable(abc.ABC): + """ + Multiple scores are supported. + If a hash collision cannot be resolved during insertion, the key with the lower score will be evicted. + The value of the table is the index/ID of each key in the table, which is read-only. + + key_type, index_type, offset_type, score_specs, result_type are the same for tables in the same group. + """ + + @property + @abc.abstractmethod + def key_type(self) -> torch.dtype: + """ + Return the key type. + """ + + @property + def index_type(self) -> torch.dtype: + """ + Return the index type. + """ + return torch.int64 + + @property + @abc.abstractmethod + def score_specs( + self, + score_names: List[str] = None, + ) -> List[ScoreSpec]: + """ + Return the score specifics. + """ + + @property + def result_type(self) -> torch.dtype: + """ + Return the insert-result type. + """ + return torch.uint8 + + @property + def offset_type(self) -> torch.dtype: + """ + Return the offset type, used for e.g. table range. + """ + return torch.int64 + + @property + @abc.abstractmethod + def table_names( + self, + table_names: List[str] = None, + ) -> List[str]: + """ + Return the table names in the group. + """ + + @abc.abstractmethod + def lookup( + self, + table_range: torch.Tensor, + keys: torch.Tensor, + scores: List[ScoreArg], + founds: Optional[torch.Tensor], + indices: torch.Tensor = None, + ) -> None: + """ + TODO: kernel fusion + Argument: + missing_table_range: torch.Tensor + missing_keys: torch.Tensor=None + missing_indices: torch.Tensor=None + missing_scores: List[ScoreArg]=None + Returns: + num_missing: int + """ + + @abc.abstractmethod + def insert( + self, + table_range: torch.Tensor, + keys: torch.Tensor, + scores: List[ScoreArg], + indices: Optional[torch.Tensor] = None, + insert_results: Optional[torch.Tensor] = None, + ) -> None: + """ + Keys have to be unique. + Indices is output buffer if provided. + """ + + @abc.abstractmethod + def insert_and_evict( + self, + table_range: torch.Tensor, + keys: torch.Tensor, + scores: List[ScoreArg], + indices: Optional[torch.Tensor] = None, + insert_results: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Keys have to be unique. + Indices is output buffer if provided. + """ + + num_evicted: int + missing_table_range: torch.Tensor + evicted_keys: torch.Tensor + evicted_indices: torch.Tensor + evicted_scores: List[torch.Tensor] + return ( + num_evicted, + missing_table_range, + evicted_keys, + evicted_indices, + evicted_scores, + ) + + @abc.abstractmethod + def erase( + self, + table_range: torch.Tensor, + keys: torch.Tensor, + ) -> None: + """ + Erase Keys. + """ + + @abc.abstractmethod + def load( + self, + table_names: List[str], + key_files: List[str], + score_files: List[Dict[str, str]], + ) -> None: + """ + Load keys and scores from input file path. + + Args: + table_names: List[str] + key_files: List[str], + score_files: List[Dict[str, str]]: Dict from score name to score file path. + """ + + @abc.abstractmethod + def dump( + self, + table_names: List[str], + key_files: List[str], + score_files: List[Dict[str, str]], + ) -> None: + """ + Dump keys and scores to output file path. + + Args: + table_names: List[str] + key_files: List[str], + score_files: List[Dict[str, str]]: Dict from score name to score file path. + """ + + @abc.abstractmethod + def capacity(self, table_name: str) -> int: + """ + Return the capacity of the table. + """ + + @abc.abstractmethod + def size(self, table_name: str) -> int: + """ + Return the size of the table. + """ + + @abc.abstractmethod + def load_factor(self, table_name: str) -> float: + """ + Return the load factor of the table. + """ + + @abc.abstractmethod + def reserve( + self, + table_name: str, + target_capacity: int, + ): + """ + Table's growth is controlled outside. + """ + + @abc.abstractmethod + def memory_usage(self, table_name: str, mem_type=MemoryType.DEVICE) -> int: + """ + Get the consumption of a specific memory type. + + Args: + table_name: str, + mem_type (MemoryType): the specific memory type, default to MemoryType.DEVICE. + """ + + +def uint64_to_int64(x): + return x if x < (1 << 63) else x - (1 << 64) + + +def murmur3_hash_64bits(key: int) -> int: + """ """ + k = key & 0xFFFFFFFFFFFFFFFF + + k ^= k >> 33 + k = (k * 0xFF51AFD7ED558CCD) & 0xFFFFFFFFFFFFFFFF + + k ^= k >> 33 + k = (k * 0xC4CEB9FE1A85EC53) & 0xFFFFFFFFFFFFFFFF + + k ^= k >> 33 + + return k + + +class LinearBucketTable(ScoredHashTable): + def __init__( + self, + capacity: int, + score_specs: List[ScoreSpec], + key_type: torch.dtype = torch.int64, + bucket_capacity: Optional[int] = None, + device: torch.device = None, + ): + self.device = ( + device + if device is not None + else torch.device("cuda", torch.cuda.current_device()) + ) + + # key type + self.key_type_ = key_type + accepted_key_types = {torch.int64, torch.uint64} + assert ( + key_type in accepted_key_types + ), "Only accept 64 bits integer as key's type." + + # score type + assert ( + len(score_specs) >= 1 and len(score_specs) <= 1 + ), "Only support at least one and at most one ScoreSpec in this version." + self.score_specs_ = sorted( + score_specs, key=lambda x: (not x.is_reduction, x.priority) + ) + assert self.score_specs_[0].is_reduction is True + accepted_score_types = {torch.uint64} + self.score_types_ = [] + self.score_names_ = [] + for score_spec in self.score_specs_: + assert ( + score_spec.dtype in accepted_score_types + ), "Only accept 64 bits unsigned integer as score's type." + self.score_types_.append(score_spec.dtype) + self.score_names_.append(score_spec.name) + + # digest type + self.digest_type_ = torch.uint8 + + # capacity & bucket capacity + if bucket_capacity is None: + bucket_capacity = 128 + + assert capacity > 0 and bucket_capacity > 0 and capacity >= bucket_capacity + max_load_bytes = 16 + digest_load_dim = max_load_bytes // dtype_to_bytes(self.digest_type_) + if bucket_capacity % digest_load_dim == 0: + self.bucket_capacity_ = bucket_capacity + else: + self.bucket_capacity_ = ( + (bucket_capacity + digest_load_dim - 1) // digest_load_dim + ) * digest_load_dim + # self.bucket_capacity_ = _next_power_of_2(self.bucket_capacity_) + + if self.bucket_capacity_ != bucket_capacity: + warnings.warn( + f"Bucket capacity is rounded from {bucket_capacity} to {self.bucket_capacity_}.", + UserWarning, + ) + self.num_buckets_ = ( + capacity + self.bucket_capacity_ - 1 + ) // self.bucket_capacity_ + self.capacity_ = self.num_buckets_ * self.bucket_capacity_ + if self.capacity_ != capacity: + warnings.warn( + f"Table capacity is rounded from {capacity} to {self.capacity_}.", + UserWarning, + ) + + # storage + self.fileds_type_ = [self.key_type_, self.digest_type_] + self.score_types_ + fields_byte = [dtype_to_bytes(x) for x in self.fileds_type_] + + self.storage_bytes_ = ( + sum(fields_byte) * self.bucket_capacity_ * self.num_buckets_ + ) + self.table_storage_ = torch.empty( + self.storage_bytes_, dtype=torch.uint8, device=self.device + ) + + self.keys_, self.digests_, *self.scores_list = table_partition( + self.table_storage_, + self.fileds_type_, + self.bucket_capacity_, + self.num_buckets_, + ) + self._init_table() + + self.bucket_sizes = torch.zeros( + self.num_buckets_, dtype=torch.int32, device=self.device + ) + + def _init_table( + self, + ): + # init keys + empty_key = 0xFFFFFFFFFFFFFFFF + if self.key_type_ == torch.int64: + empty_key = uint64_to_int64(empty_key) + self.keys_.fill_(empty_key) + + # init scores + empty_score = 0 + for scores in self.scores_list: + scores.fill_(empty_score) + + # init digest + empty_digest = (murmur3_hash_64bits(empty_key) >> 32) & 0xFF + self.digests_.fill_(empty_digest) + + @property + def key_type(self) -> torch.dtype: + """ + Return the key type. + """ + return self.key_type_ + + @property + def score_specs( + self, + score_names: List[str] = None, + ) -> List[ScoreSpec]: + """ + Return the score specifics. + """ + return self.score_specs_ + + def _parse_scores( + self, + scores: List[ScoreArg], + ) -> Tuple[List[torch.Tensor], List[ScorePolicy], List[bool]]: + scores_ = [None for _ in self.score_names_] + policies = [ScorePolicy.CONST for _ in self.score_names_] + is_returns = [False for _ in self.score_names_] + + for score in scores: + index = self.score_names_.index(score.name) + if score.is_return: + assert score.value is not None + scores_[index] = score.value + policies[index] = ( + score.policy + if score.policy is not None + else self.score_specs_[index].policy + ) + is_returns[index] = score.is_return + + if score.policy == ScorePolicy.GLOBAL_TIMER: + assert ( + self.score_specs_[index].dtype == torch.uint64 + ), "Global timer can only work for torch.uint64" + + return scores_, policies, is_returns + + def lookup( + self, + keys: torch.Tensor, + scores: List[ScoreArg], + founds: Optional[torch.Tensor], + indices: torch.Tensor = None, + ) -> None: + """ + TODO: kernel fusion + Argument:: + missing_keys: torch.Tensor=None + missing_indices: torch.Tensor=None + missing_scores: List[ScoreArg]=None + Returns: + num_missing: int + """ + scores_, policies, is_returns = self._parse_scores(scores) + + table_lookup( + self.table_storage_, + self.fileds_type_, + self.bucket_capacity_, + keys, + scores_, + policies, + is_returns, + founds, + indices, + ) + + def insert( + self, + keys: torch.Tensor, + scores: List[ScoreArg], + indices: Optional[torch.Tensor] = None, + insert_results: Optional[torch.Tensor] = None, + ) -> None: + """ + Keys have to be unique. + Indices is output buffer if provided. + """ + + scores_, policies, is_returns = self._parse_scores(scores) + + table_insert( + self.table_storage_, + self.fileds_type_, + self.bucket_capacity_, + self.bucket_sizes, + keys, + scores_, + policies, + is_returns, + indices, + insert_results, + ) + + def insert_and_evict( + self, + keys: torch.Tensor, + scores: List[ScoreArg], + indices: Optional[torch.Tensor] = None, + insert_results: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """ + Keys have to be unique. + Indices is output buffer if provided. + """ + + scores_, policies, is_returns = self._parse_scores(scores) + + batch = keys.numel() + num_evicted = torch.zeros(1, dtype=COUNTER_TYPE, device=keys.device) + evicted_keys = torch.empty(batch, dtype=self.key_type_, device=keys.device) + evicted_indices = torch.empty(batch, dtype=self.index_type, device=keys.device) + evicted_scores_list = [ + torch.empty(batch, dtype=dtype, device=keys.device) + for dtype in self.score_types_ + ] + + table_insert_and_evict( + self.table_storage_, + self.fileds_type_, + self.bucket_capacity_, + self.bucket_sizes, + keys, + scores_, + policies, + is_returns, + insert_results, + indices, + num_evicted, + evicted_keys, + evicted_indices, + evicted_scores_list, + ) + + h_num_evicted = num_evicted.cpu().item() + return ( + h_num_evicted, + evicted_keys[:h_num_evicted], + evicted_indices[:h_num_evicted], + [evicted_scores[:h_num_evicted] for evicted_scores in evicted_scores_list], + ) + + def erase( + self, + keys: torch.Tensor, + ) -> None: + """ + Erase Keys + """ + table_erase( + self.table_storage_, + self.fileds_type_, + self.bucket_capacity_, + self.bucket_sizes, + keys, + ) + + def load( + self, + key_file: str, + score_files: Dict[str, str], + ) -> None: + """ + Load keys and scores from input file path. + + Args: + key_file (str): the file path of keys. + score_files: Dict[str, str]: Dict from score name to score file path. + """ + + for score_name in self.score_names_: + if score_name not in score_files or not os.path.exists( + score_files[score_name] + ): + print( + f"Will not load scores for {score_name}, as not provide the file path or file path not existed." + ) + + fkey = open(key_file, "rb") + + fscores: Dict[str, Any] = {} + for score_name, score_path in score_files.items(): + if score_name not in self.score_names_: + print( + f"Score name {score_name} not existed, will not load from {score_path}." + ) + elif os.path.exists(score_path): + fscores[score_name] = open(score_path, "rb") + + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + num_keys = os.path.getsize(key_file) // KEY_TYPE.itemsize + + for score_name in fscores.keys(): + num_scores = os.path.getsize(score_files[score_name]) // SCORE_TYPE.itemsize + + if num_keys != num_scores: + raise ValueError( + f"The number of keys({num_keys}) in {key_file} does not match with number of scores({num_keys}) in {score_files[score_name]}." + ) + + world_size = dist.get_world_size() if dist.is_initialized() else 1 + rank = dist.get_rank() if dist.is_initialized() else 0 + + dump_timestamp = device_timestamp() + + batch_size = 65536 + for start in range(0, num_keys, batch_size): + num_keys_to_read = min(num_keys - start, batch_size) + keys_bytes = fkey.read(KEY_TYPE.itemsize * num_keys_to_read) + + score_bytes_dict: Dict[str, Any] = {} + for score_name in fscores.keys(): + score_bytes_dict[score_name] = fscores[score_name].read( + SCORE_TYPE.itemsize * num_keys_to_read + ) + + keys = torch.tensor( + np.frombuffer(keys_bytes, dtype=torch_dtype_to_np_dtype[KEY_TYPE]), + dtype=KEY_TYPE, + device=device, + ) + scores_dict: Dict[str, torch.Tensor] = {} + for score_name, score_bytes in score_bytes_dict.items(): + scores = torch.tensor( + np.frombuffer( + score_bytes, dtype=torch_dtype_to_np_dtype[SCORE_TYPE] + ), + dtype=SCORE_TYPE, + device=device, + ) + index = self.score_names_.index(score_name) + if self.score_specs_[index].policy == ScorePolicy.GLOBAL_TIMER: + scores = torch.clamp(dump_timestamp - scores, min=0) + scores_dict[score_name] = scores + + if world_size > 1: + masks = keys % world_size == rank + keys = keys[masks] + for score_name in scores_dict: + scores_dict[score_name] = scores_dict[score_name][masks] + + score_args = [] + for score_name, scores in scores_dict.items(): + score_args.append( + ScoreArg(name=score_name, value=scores, policy=ScorePolicy.ASSIGN) + ) + self.insert(keys, score_args) + + fkey.close() + for name in fscores.keys(): + fscores[name].close() + + def _batched_export_keys_scores( + self, + score_names: List[str], + target_device: torch.device, + batch_size: int = 65536, + ) -> Iterator[Tuple[torch.Tensor, Dict[str, torch.Tensor]]]: + """ + export keys, {score_name: scores} + """ + + search_capacity = self.capacity_ + + offset = 0 + + device = self.device + + key_dtype = self.key_type_ + score_dtype = torch.uint64 + + while offset < search_capacity: + batch_ = min(batch_size, search_capacity - offset) + + keys = torch.empty(batch_, dtype=key_dtype, device=device) + scores_list = [] + for score_name in self.score_names_: + if score_name in score_names: + scores_list.append( + torch.zeros(batch_, dtype=score_dtype, device=device) + ) + else: + scores_list.append(None) + d_counter = torch.zeros(1, dtype=COUNTER_TYPE, device=device) + + table_export_batch( + self.table_storage_, + self.fileds_type_, + self.bucket_capacity_, + batch_, + offset, + d_counter, + keys, + scores_list, + ) + + actual_length = d_counter.item() + if actual_length > 0: + named_scores: Dict[str, torch.Tensor] = {} + for score_name in score_names: + index = self.score_names_.index(score_name) + scores_ = scores_list[index] + named_scores[score_name] = ( + scores_[:actual_length].to(SCORE_TYPE).to(target_device) + ) + + yield ( + keys[:actual_length].to(KEY_TYPE).to(target_device), + named_scores, + ) + offset += batch_size + + def dump( + self, + key_file: str, + score_files: Dict[str, str], + ) -> None: + """ + Dump keys and scores to output file path. + + Args: + key_file (str): the file path of keys. + score_files: Dict[str, str]: Dict from score name to score file path. + """ + + fkey = open(key_file, "wb") + fscores: Dict[str, Any] = {} + for score_name, score_path in score_files.items(): + if score_name not in self.score_names_: + print( + f"Score name {score_name} not existed, will not dump to {score_path}." + ) + else: + fscores[score_name] = open(score_path, "wb") + + dump_timestamp = device_timestamp() + + for keys, named_scores in self._batched_export_keys_scores( + fscores.keys(), self.device + ): + fkey.write(keys.cpu().numpy().tobytes()) + for name, scores in named_scores.items(): + index = self.score_names_.index(name) + if self.score_specs_[index].policy == ScorePolicy.GLOBAL_TIMER: + scores = dump_timestamp - scores + fscores[name].write(scores.cpu().numpy().tobytes()) + + fkey.close() + for fscore in fscores.values(): + fscore.close() + + return + + def capacity(self) -> int: + """ + Return the capacity of the table. + """ + return self.capacity_ + + def size(self) -> int: + """ + Return the size of the table. + """ + return self.bucket_sizes.sum() + + def load_factor(self) -> float: + """ + Return the load factor of the table. + """ + return self.bucket_sizes.sum() / self.capacity_ + + def reserve( + self, + target_capacity, + ): + """ + Table's growth is controlled outside. + """ + raise NotImplementedError + + def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: + """ + Get the consumption of a specific memory type. + + Args: + mem_type (MemoryType): the specific memory type, default to MemoryType.DEVICE. + """ + return ( + self.storage_bytes_ + + self.bucket_sizes.numel() * self.bucket_sizes.element_size() + ) + + +def get_scored_table( + capacity: int, + bucket_capacity: Optional[int] = None, + key_type: Optional[torch.dtype] = torch.int64, + score_specs: List[ScoreSpec] = [ + ScoreSpec(name="timestamp", policy=ScorePolicy.GLOBAL_TIMER) + ], + device: torch.device = None, + probing_type=ProbingType.LINEAR, + reduction_type=ReductionType.LINEAR, + bucket_load_factor=0.5, # used when probing_type=ProbingType.CHAINED +) -> ScoredHashTable: + if probing_type == ProbingType.LINEAR and reduction_type == ReductionType.LINEAR: + return LinearBucketTable( + capacity, + score_specs, + key_type=key_type, + bucket_capacity=bucket_capacity, + device=device, + ) + else: + raise NotImplementedError + + +def get_grouped_scored_table( + capacities: List[int], + bucket_capacity: Optional[List[int]] = None, + key_type: Optional[torch.dtype] = torch.int64, + score_specs: List[ScoreSpec] = [ + ScoreSpec(name="timestamp", policy=ScorePolicy.GLOBAL_TIMER) + ], + device: torch.device = None, + probing_type=ProbingType.LINEAR, + reduction_type=ReductionType.LINEAR, + bucket_load_factor=0.5, # used when probing_type=ProbingType.CHAINED +) -> GroupedScoredHashTable: + raise NotImplementedError diff --git a/corelib/dynamicemb/dynamicemb/shard/embedding.py b/corelib/dynamicemb/dynamicemb/shard/embedding.py index bebc7960e..3f423ca26 100644 --- a/corelib/dynamicemb/dynamicemb/shard/embedding.py +++ b/corelib/dynamicemb/dynamicemb/shard/embedding.py @@ -45,16 +45,48 @@ from torchrec.modules.embedding_modules import EmbeddingCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor -from ..dynamicemb_config import DynamicEmbKernel +from ..dynamicemb_config import DynamicEmbKernel, DynamicEmbScoreStrategy from ..planner.rw_sharding import RwSequenceDynamicEmbeddingSharding from ..unique_op import UniqueOp +class DynamicEmbeddingCollectionContext(EmbeddingCollectionContext): + """Extended EmbeddingCollectionContext that includes frequency_counters for LFU strategy.""" + + def __init__( + self, + sharding_contexts: Optional[List[SequenceShardingContext]] = None, + input_features: Optional[List[KeyedJaggedTensor]] = None, + reverse_indices: Optional[List[torch.Tensor]] = None, + seq_vbe_ctx: Optional[List] = None, + frequency_counters: Optional[List[torch.Tensor]] = None, + ) -> None: + super().__init__( + sharding_contexts, input_features, reverse_indices, seq_vbe_ctx + ) + self.frequency_counters: List[torch.Tensor] = frequency_counters or [] + + class ShardedDynamicEmbeddingCollection(ShardedEmbeddingCollection): supported_compute_kernels: List[str] = [ kernel.value for kernel in EmbeddingComputeKernel ] + [DynamicEmbKernel] + def __init__( + self, + *args, + score_strategy: Optional[DynamicEmbScoreStrategy] = None, + has_admit_strategy: bool = False, + **kwargs, + ): + super().__init__(*args, **kwargs) + # Store the global score strategy + self._score_strategy = score_strategy + self._is_lfu_enabled = ( + (score_strategy == DynamicEmbScoreStrategy.LFU) if score_strategy else False + ) + self._has_admit_strategy = has_admit_strategy + @classmethod def create_embedding_sharding( cls, @@ -103,6 +135,9 @@ def _create_hash_size_info( ctx: Optional[EmbeddingCollectionContext] = None, ) -> None: super()._create_hash_size_info(feature_names) + + # _is_lfu_enabled is already set in __init__ from score_strategy parameter + if self._use_index_dedup: reserve_keys = torch.tensor(2, dtype=torch.int64, device=self._device) reserve_vals = torch.tensor(2, dtype=torch.uint64, device=self._device) @@ -143,7 +178,7 @@ def _create_hash_size_info( def _dedup_indices( self, - ctx: EmbeddingCollectionContext, + ctx: DynamicEmbeddingCollectionContext, input_feature_splits: List[KeyedJaggedTensor], ) -> List[KeyedJaggedTensor]: with record_function("## dedup_ec_indices ##"): @@ -156,7 +191,7 @@ def _dedup_indices( d_table_offset = self.get_buffer( f"_nonfuse_table_feature_offsets_device_{i}" ) - + input_feature._values = input_feature._values.contiguous() # for debug # hash_size_cumsum = self.get_buffer( # f"_hash_size_cumsum_tensor_{i}" @@ -183,7 +218,6 @@ def _dedup_indices( offsets = input_feature.offsets() lengths = input_feature.lengths() dtype_convert = False - torch.int64 if indices.dtype != torch.int64: indices.dtype indices_input = indices.to(torch.int64) @@ -211,29 +245,67 @@ def _dedup_indices( new_offsets = torch.empty_like(offsets, device=self._device) new_lengths = torch.empty_like(lengths, device=self._device) - dedup_input_indices( - indices_input, - offsets, - h_table_offset, - d_table_offset, - table_num, - local_batchsize, - reverse_idx, - h_unique_nums, - d_unique_nums, - h_unique_offsets, - d_unique_offsets, - unique_idx_list, - new_offsets, - new_lengths, - self._device_num_sms, - self._unique_op, - ) + + # Only create frequency_counters if LFU strategy is enabled + # For non-LFU strategies, pass empty tensor (C++ extension will check size) + if self._is_lfu_enabled or self._has_admit_strategy: + # TODO: use only one frequency_counters tensor for all tables + # frequency_counters = torch.zeros_like( + # indices_input, device=self._device, dtype=torch.uint64 + # ) + frequency_counters_list = [ + torch.zeros_like( + indices_input, dtype=torch.uint64, device=self._device + ) + for i in range(table_num) + ] + dedup_input_indices( + indices_input, + offsets, + h_table_offset, + d_table_offset, + table_num, + local_batchsize, + reverse_idx, + h_unique_nums, + d_unique_nums, + h_unique_offsets, + d_unique_offsets, + unique_idx_list, + new_offsets, + new_lengths, + self._device_num_sms, + self._unique_op, + frequency_counters_list, + ) + else: + # Empty tensor for non-LFU and non-admit strategies + dedup_input_indices( + indices_input, + offsets, + h_table_offset, + d_table_offset, + table_num, + local_batchsize, + reverse_idx, + h_unique_nums, + d_unique_nums, + h_unique_offsets, + d_unique_offsets, + unique_idx_list, + new_offsets, + new_lengths, + self._device_num_sms, + self._unique_op, + ) + unique_num = h_unique_offsets[-1].item() unique_idx = torch.empty( unique_num, dtype=torch.int64, device=indices.device ) - + frequency_counters = torch.empty( + unique_num, device=self._device, dtype=torch.uint64 + ) # TODO: check non_blocking=True is valid for device tensor to device tensor for i in range(table_num): start_pos = h_unique_offsets[i].item() @@ -242,6 +314,10 @@ def _dedup_indices( unique_idx[start_pos:end_pos].copy_( unique_idx_list[i][:length], non_blocking=True ) + if self._is_lfu_enabled or self._has_admit_strategy: + frequency_counters[start_pos:end_pos].copy_( + frequency_counters_list[i][:length], non_blocking=True + ) if dtype_convert: unique_idx_out = torch.empty( @@ -257,15 +333,21 @@ def _dedup_indices( offsets=new_offsets, values=unique_idx_out, ) - ctx.input_features.append(input_feature) ctx.reverse_indices.append(reverse_idx) + # Only store frequency_counters if LFU or admit strategy is enabled + if self._is_lfu_enabled or self._has_admit_strategy: + ctx.frequency_counters.append(frequency_counters) + assert frequency_counters.size(0) == unique_idx_out.size( + 0 + ), f"Frequency counters size {frequency_counters.size(0)} doesn't match unique indices size {unique_idx_out.size(0)}" + features_by_shards.append(dedup_features) return features_by_shards def input_dist( self, - ctx: EmbeddingCollectionContext, + ctx: DynamicEmbeddingCollectionContext, features: KeyedJaggedTensor, ) -> Awaitable[Awaitable[KJTList]]: if self._has_uninitialized_input_dist: @@ -287,7 +369,20 @@ def input_dist( features_by_shards = self._dedup_indices(ctx, features_by_shards) awaitables = [] - for input_dist, features in zip(self._input_dists, features_by_shards): + for i, (input_dist, features) in enumerate( + zip(self._input_dists, features_by_shards) + ): + # Attach frequency counters as weights if LFU strategy is enabled + if ( + self._use_index_dedup + and (self._is_lfu_enabled or self._has_admit_strategy) + and len(ctx.frequency_counters) > i + ): + frequency_counters = ctx.frequency_counters[i] + features._weights = frequency_counters.float() + else: + features._weights = None + awaitables.append(input_dist(features)) ctx.sharding_contexts.append( SequenceShardingContext( @@ -303,6 +398,18 @@ def input_dist( self._compute_sequence_vbe_context(ctx, unpadded_features) return KJTListSplitsAwaitable(awaitables, ctx) + # def create_context(self) -> DynamicEmbeddingCollectionContext: + # return DynamicEmbeddingCollectionContext(sharding_contexts=[]) + + def create_context(self) -> DynamicEmbeddingCollectionContext: + # pre-allocate frequency_counters list, ensure all ranks have the same structure + frequency_counters = ( + [] if not (self._is_lfu_enabled or self._has_admit_strategy) else None + ) + return DynamicEmbeddingCollectionContext( + sharding_contexts=[], frequency_counters=frequency_counters + ) + class DynamicEmbeddingCollectionSharder(EmbeddingCollectionSharder): """ @@ -324,6 +431,27 @@ def shard( device: Optional[torch.device] = None, module_fqn: Optional[str] = None, ) -> ShardedEmbeddingCollection: + # Extract global score_strategy from params (only once, as it's a global configuration) + # Strategy is expected to be consistent across all tables + global_score_strategy = None + has_admit_strategy = False + if global_score_strategy is None: + for param_name, param_sharding in params.items(): + if ( + hasattr(param_sharding, "dynamicemb_options") + and param_sharding.dynamicemb_options + ): + if param_sharding.dynamicemb_options.score_strategy is not None: + global_score_strategy = ( + param_sharding.dynamicemb_options.score_strategy + ) + + if param_sharding.dynamicemb_options.admit_strategy is not None: + has_admit_strategy = True + + break + + # Pass score_strategy directly as a parameter to ShardedDynamicEmbeddingCollection return ShardedDynamicEmbeddingCollection( module, params, @@ -332,4 +460,6 @@ def shard( device, qcomm_codecs_registry=self.qcomm_codecs_registry, use_index_dedup=self._use_index_dedup, + score_strategy=global_score_strategy, # Pass as direct parameter + has_admit_strategy=has_admit_strategy, ) diff --git a/corelib/dynamicemb/dynamicemb/types.py b/corelib/dynamicemb/dynamicemb/types.py new file mode 100644 index 000000000..6756b6254 --- /dev/null +++ b/corelib/dynamicemb/dynamicemb/types.py @@ -0,0 +1,409 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import enum +from dataclasses import dataclass +from typing import Generic, Optional, Tuple, TypeVar + +import numpy as np +import torch +from dynamicemb_extensions import InitializerArgs + + +@enum.unique +class MemoryType(enum.Enum): + DEVICE = "device" # memory allocated using cudaMalloc/cudaMallocAsync + MANAGED = "managed" # memory allocated using cudaMallocManaged + PINNED_HOST = "pinned_host" # memory allocated using cudaHostAlloc/cudaMallocHost + HOST = "host" # system memory allocated using e.g. malloc. + + +class DynamicEmbInitializerMode(enum.Enum): + """ + Enumeration for different modes of initializing dynamic embedding vector values. + + Attributes + ---------- + NORMAL : str + Normal Distribution. + UNIFORM : str + Uniform distribution of random values. + CONSTANT : str + All dynamic embedding vector values are a given constant. + DEBUG : str + Debug value generation mode for testing. + """ + + NORMAL = "normal" + TRUNCATED_NORMAL = "truncated_normal" + UNIFORM = "uniform" + CONSTANT = "constant" + DEBUG = "debug" + + +@dataclass +class DynamicEmbInitializerArgs: + """ + Arguments for initializing dynamic embedding vector values. + + Attributes + ---------- + mode : DynamicEmbInitializerMode + The mode of initialization, one of the DynamicEmbInitializerMode values. + mean : float, optional + The mean value for (truncated) normal distributions. Defaults to 0.0. + std_dev : float, optional + The standard deviation for (truncated) normal distributions. Defaults to 1.0. + lower : float, optional + The lower bound for uniform/truncated_normal distribution. Defaults to 0.0. + upper : float, optional + The upper bound for uniform/truncated_normal distribution. Defaults to 1.0. + value : float, optional + The constant value for constant initialization. Defaults to 0.0. + """ + + mode: DynamicEmbInitializerMode = DynamicEmbInitializerMode.UNIFORM + mean: float = 0.0 + std_dev: float = 1.0 + lower: float = None + upper: float = None + value: float = 0.0 + + def __eq__(self, other): + if not isinstance(other, DynamicEmbInitializerArgs): + return NotImplementedError + if self.mode == DynamicEmbInitializerMode.NORMAL: + return self.mean == other.mean and self.std_dev == other.std_dev + elif self.mode == DynamicEmbInitializerMode.TRUNCATED_NORMAL: + return ( + self.mean == other.mean + and self.std_dev == other.std_dev + and self.lower == other.lower + and self.upper == other.upper + ) + elif self.mode == DynamicEmbInitializerMode.UNIFORM: + return self.lower == other.lower and self.upper == other.upper + elif self.mode == DynamicEmbInitializerMode.CONSTANT: + return self.value == other.value + return True + + def __ne__(self, other): + if not isinstance(other, DynamicEmbInitializerArgs): + return NotImplementedError + return not (self == other) + + def as_ctype(self) -> InitializerArgs: + return InitializerArgs( + self.mode.value, + self.mean, + self.std_dev, + self.lower if self.lower else 0.0, + self.upper if self.upper else 1.0, + self.value, + ) + + +TableOptionType = TypeVar("TableOptionType") +OptimizerInterface = TypeVar("OptimizerInterface") + +KEY_TYPE = torch.int64 +EMBEDDING_TYPE = torch.float32 +SCORE_TYPE = torch.int64 +OPT_STATE_TYPE = torch.float32 +COUNTER_TYPE = torch.int64 + +torch_dtype_to_np_dtype = { + torch.uint64: np.uint64, + torch.int64: np.int64, + torch.float32: np.float32, +} + + +# make it standalone to avoid recursive references. +class Storage(abc.ABC, Generic[TableOptionType, OptimizerInterface]): + @abc.abstractmethod + def __init__( + self, + options: TableOptionType, + optimizer: OptimizerInterface, + ): + pass + + @abc.abstractmethod + def find( + self, + unique_keys: torch.Tensor, + unique_vals: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + missing_scores: torch.Tensor + return num_missing, missing_keys, missing_indices, missing_scores + + @abc.abstractmethod + def find_embeddings( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: int + missing_keys: torch.Tensor + missing_indices: torch.Tensor + missing_scores: torch.Tensor + return num_missing, missing_keys, missing_indices, missing_scores + + @abc.abstractmethod + def insert( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: Optional[torch.Tensor] = None, + ) -> None: + pass + + @abc.abstractmethod + def update( + self, keys: torch.Tensor, grads: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + @abc.abstractmethod + def enable_update(self) -> bool: + ... + + @abc.abstractmethod + def dump( + self, + meta_file_path: str, + emb_key_path: str, + embedding_file_path: str, + score_file_path: Optional[str], + opt_file_path: Optional[str], + ) -> None: + pass + + @abc.abstractmethod + def load( + self, + meta_file_path: str, + emb_file_path: str, + embedding_file_path: str, + score_file_path: Optional[str], + opt_file_path: Optional[str], + include_optim: bool, + ) -> None: + pass + + @abc.abstractmethod + def embedding_dtype( + self, + ) -> torch.dtype: + pass + + @abc.abstractmethod + def embedding_dim( + self, + ) -> int: + pass + + @abc.abstractmethod + def value_dim( + self, + ) -> int: + pass + + @abc.abstractmethod + def init_optimizer_state( + self, + ) -> float: + pass + + +class Cache(abc.ABC): + @abc.abstractmethod + def find( + self, + unique_keys: torch.Tensor, + unique_vals: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: int + missing_keys: torch.Tensor + missing_indices: torch.Tensor + missing_scores: torch.Tensor + return num_missing, missing_keys, missing_indices, missing_scores + + @abc.abstractmethod + def find_embeddings( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: int + missing_keys: torch.Tensor + missing_indices: torch.Tensor + missing_scores: torch.Tensor + return num_missing, missing_keys, missing_indices, missing_scores + + @abc.abstractmethod + def find_missed_keys( + self, + unique_keys: torch.Tensor, + founds: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + num_missing: int + missing_keys: torch.Tensor + missing_indices: torch.Tensor + missing_scores: torch.Tensor + return num_missing, missing_keys, missing_indices, missing_scores + + @abc.abstractmethod + def insert_and_evict( + self, + keys: torch.Tensor, + values: torch.Tensor, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + num_evicted: int + evicted_keys: torch.Tensor + evicted_values: torch.Tensor + evicted_scores: torch.Tensor + return num_evicted, evicted_keys, evicted_values, evicted_scores + + @abc.abstractmethod + def update( + self, keys: torch.Tensor, grads: torch.Tensor + ) -> Tuple[int, torch.Tensor, torch.Tensor]: + num_missing: int + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + @abc.abstractmethod + def flush(self, storage: Storage) -> None: + pass + + @abc.abstractmethod + def reset( + self, + ) -> None: + pass + + @abc.abstractmethod + def cache_metrics( + self, + ) -> torch.Tensor: + pass + + @abc.abstractmethod + def set_record_cache_metrics(self, record: bool) -> None: + pass + + +class Counter(abc.ABC): + """ + Interface of a counter table which maps a key to a counter. + """ + + @abc.abstractmethod + def add( + self, keys: torch.Tensor, frequencies: torch.Tensor, inplace: bool + ) -> torch.Tensor: + """ + Add keys with frequencies to the `Counter` and get accumulated counter of each key. + For not existed keys, the frequencies will be assigned directly. + For existing keys, the frequencies will be accumulated. + + Args: + keys (torch.Tensor): The input keys, should be unique keys. + frequencies (torch.Tensor): The input frequencies, serve as initial or incremental values of frequencies' states. + inplace: If true then store the accumulated_frequencies to counter. + + Returns: + accumulated_frequencies (torch.Tensor): the frequencies' state in the `Counter` for the input keys. + """ + accumulated_frequencies: torch.Tensor + return accumulated_frequencies + + @abc.abstractmethod + def erase(self, keys) -> None: + """ + Erase keys form the `Counter`. + + Args: + keys (torch.Tensor): The input keys to be erased. + """ + + @abc.abstractmethod + def memory_usage(self, mem_type=MemoryType.DEVICE) -> int: + """ + Get the consumption of a specific memory type. + + Args: + mem_type (MemoryType): the specific memory type, default to MemoryType.DEVICE. + """ + + @abc.abstractmethod + def load(self, key_file, counter_file) -> None: + """ + Load keys and frequencies from input file path. + + Args: + key_file (str): the file path of keys. + counter_file (str): the file path of frequencies. + """ + + @abc.abstractmethod + def dump(self, key_file, counter_file) -> None: + """ + Dump keys and frequencies to output file path. + + Args: + key_file (str): the file path of keys. + counter_file (str): the file path of frequencies. + """ + + +class AdmissionStrategy(abc.ABC): + @abc.abstractmethod + def admit( + self, + keys: torch.Tensor, + frequencies: torch.Tensor, + ) -> torch.Tensor: + """ + Admit keys with frequencies >= threshold. + """ + + @abc.abstractmethod + def initialize_non_admitted_embeddings( + self, + buffer: torch.Tensor, + indices: torch.Tensor, + ) -> None: + """ + Initialize the embeddings for the keys that are not admitted. + """ diff --git a/corelib/dynamicemb/dynamicemb/utils.py b/corelib/dynamicemb/dynamicemb/utils.py index f44c87970..174d61134 100644 --- a/corelib/dynamicemb/dynamicemb/utils.py +++ b/corelib/dynamicemb/dynamicemb/utils.py @@ -13,10 +13,63 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +# pyre-strict +from typing import List, Optional, Set, Type, Union, cast import torch +# from torchrec.distributed import ModuleShardingPlan +from torchrec.modules.embedding_modules import ( + EmbeddingBagCollection, + EmbeddingCollection, +) + +TORCHREC_TYPES: Set[Type[Union[EmbeddingBagCollection, EmbeddingCollection]]] = { + EmbeddingBagCollection, + EmbeddingCollection, +} + + +def tabulate( + table: List[List[Union[str, int]]], + headers: Optional[List[str]] = None, + sub_headers: bool = False, +) -> str: + """ + Format a table as a string. + Parameters: + table (list of lists or list of tuples): The data to be formatted as a table. + headers (list of strings, optional): The column headers for the table. If not provided, the first row of the table will be used as the headers. + Returns: + str: A string representation of the table. + """ + if headers is None: + headers = table[0] + table = table[1:] + headers = cast(List[str], headers) + rows = [] + # Determine the maximum width of each column + col_widths = [max([len(str(item)) for item in column]) for column in zip(*table)] + col_widths = [max(i, len(j)) for i, j in zip(col_widths, headers)] + # Format each row of the table + for row in table: + row_str = " | ".join( + [str(item).ljust(width) for item, width in zip(row, col_widths)] + ) + rows.append(row_str) + # Add the header row and the separator line + rows.insert( + 0, + " | ".join( + [header.center(width) for header, width in zip(headers, col_widths)] + ), + ) + + rows.insert(1, " | ".join(["-" * width for width in col_widths])) + if sub_headers: + rows.insert(3, " | ".join(["-" * width for width in col_widths])) + return "\n".join(rows) + def assert_tensors_equal( tensor1: torch.Tensor, diff --git a/corelib/dynamicemb/example/README.md b/corelib/dynamicemb/example/README.md new file mode 100644 index 000000000..e8979691d --- /dev/null +++ b/corelib/dynamicemb/example/README.md @@ -0,0 +1,18 @@ +# Dynamicemb Example Introduction + +In short, **dynamicemb** provides distributed, high-performance dynamic embedding storage and related functions for training. + +How to run: +```shell +export NGPU=1 +bash ./run_example.sh +``` + +- The [example.py](./example.py) will show you how to train and evaluate the embedding module, as well as dump, load and incremental dump the module, and this example also demonstrates how to customize embedding admissions. + + +- For detailed explanations of specific APIs and parameters, please refer to [API Doc](../DynamicEmb_APIs.md). + +- For usage of external storage, Refer to demo `PyDictStorage` in [uint test](../test/test_batched_dynamic_embedding_tables_v2.py). + +***dynamicemb** supports not only `EmbeddingCollection` but also `EmbeddingBagCollection`. However, due to the requirements of generative recommendations, dynamicemb focuses on performance optimization of `EmbeddingCollection` while providing full functional support for `EmbeddingBagCollection`. And we use `EmbeddingCollection` as an example.* \ No newline at end of file diff --git a/corelib/dynamicemb/example/example.py b/corelib/dynamicemb/example/example.py index 96efca59a..04780f0e6 100644 --- a/corelib/dynamicemb/example/example.py +++ b/corelib/dynamicemb/example/example.py @@ -4,6 +4,7 @@ import os import shutil import urllib.request +import warnings import zipfile from typing import Dict, List @@ -19,20 +20,24 @@ DynamicEmbLoad, DynamicEmbScoreStrategy, DynamicEmbTableOptions, + FrequencyAdmissionStrategy, + KVCounter, ) +from dynamicemb.dynamicemb_config import data_type_to_dtype, get_optimizer_state_dim from dynamicemb.incremental_dump import get_score, incremental_dump +from dynamicemb.optimizer import EmbOptimType, convert_optimizer_type from dynamicemb.planner import ( DynamicEmbeddingEnumerator, DynamicEmbeddingShardingPlanner, DynamicEmbParameterConstraints, ) from dynamicemb.shard import DynamicEmbeddingCollectionSharder -from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType +from fbgemm_gpu.split_embedding_configs import SparseType from torch.optim import Adam from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler from torchrec import DataType -from torchrec.distributed.comm import get_local_size +from torchrec.distributed.comm import get_local_rank, get_local_size from torchrec.distributed.fbgemm_qcomm_codec import ( CommType, QCommsConfig, @@ -43,18 +48,36 @@ from torchrec.distributed.planner.storage_reservations import ( HeuristicalStorageReservation, ) +from torchrec.distributed.planner.types import ShardingPlan from torchrec.distributed.types import ShardingType from torchrec.modules.embedding_configs import EmbeddingConfig from torchrec.modules.embedding_modules import EmbeddingCollection from torchrec.sparse.jagged_tensor import KeyedJaggedTensor +# Filter FBGEMM warning, make notebook clean +warnings.filterwarnings( + "ignore", message=".*torch.library.impl_abstract.*", category=FutureWarning +) + backend = "nccl" dist.init_process_group(backend=backend) + +# Set LOCAL_WORLD_SIZE if not available for proper topology configuration +if "LOCAL_WORLD_SIZE" not in os.environ: + os.environ["LOCAL_WORLD_SIZE"] = str(torch.cuda.device_count()) + +# Set LOCAL_RANK if not available (for consistency) +if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = str(get_local_rank()) + +# Set RANK if not available +if "RANK" not in os.environ: + os.environ["RANK"] = str(dist.get_rank()) + local_rank = dist.get_rank() # for one node world_size = dist.get_world_size() torch.cuda.set_device(local_rank) device = torch.device(f"cuda:{local_rank}") - # print with rank info original_print = builtins.print @@ -64,10 +87,11 @@ def rank_print(*args, **kwargs): builtins.print = rank_print +cache_ratio = 0.5 # assume we will use 50% of the HBM for cache def download_movielens(data_dir="./ml-1m"): - if local_rank == 0: + if dist.get_rank() == 0: # Use global rank for multi-node consistency os.makedirs(data_dir, exist_ok=True) if os.path.exists(os.path.join(data_dir, "ratings.dat")): print(f"MovieLens in {data_dir}") @@ -99,9 +123,13 @@ def download_movielens(data_dir="./ml-1m"): def parse_args(): parser = argparse.ArgumentParser(description="TorchRec MovieLens with dynamicemb") parser.add_argument("--train", action="store_true") + parser.add_argument("--eval", action="store_true") parser.add_argument("--load", action="store_true") parser.add_argument("--dump", action="store_true") parser.add_argument("--incremental_dump", action="store_true") + parser.add_argument("--caching", action="store_true") + parser.add_argument("--prefetch_pipeline", action="store_true") + parser.add_argument("--external_storage", action="store_true") parser.add_argument( "--data_path", @@ -133,6 +161,13 @@ def parse_args(): parser.add_argument( "--seed", type=int, default=42, help="random seed used for initialization" ) + # torchrun --standalone --nproc_per_node=${NGPU} example.py --train "$@" --admission_threshold 5 + parser.add_argument( + "--admission_threshold", + type=int, + default=0, + help="Frequency threshold for admission strategy (0 disable admission strategy, >0 enable admission strategy and only keys appearing >= threshold will be stored in tables)", + ) return parser.parse_args() @@ -341,8 +376,76 @@ def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor: return torch.sum(x.t(), dim=-1) +def get_sharder(args, optimizer_type): + # set optimizer args + learning_rate = args.lr + beta1 = 0.9 + beta2 = 0.999 + weight_decay = 0 + eps = 0.001 + + # Put args into a optimizer kwargs , which is same usage of torchrec + optimizer_kwargs = { + "optimizer": optimizer_type, + "learning_rate": learning_rate, + "beta1": beta1, + "beta2": beta2, + "weight_decay": weight_decay, + "eps": eps, + } + + fused_params = {} + fused_params[ + "output_dtype" + ] = ( + SparseType.FP32 + ) # data type of the output after lookup, and can differ from the stored. + fused_params.update(optimizer_kwargs) + fused_params[ + "prefetch_pipeline" + ] = args.prefetch_pipeline # whether enable prefetch for embedding lookup module + + # precision of all-to-all + qcomm_codecs_registry = ( + get_qcomm_codecs_registry( + qcomms_config=QCommsConfig( + # pyre-ignore + forward_precision=CommType.FP32, + # pyre-ignore + backward_precision=CommType.FP32, + ) + ) + if backend == "nccl" + else None + ) + + """ + fused_params: + items in fused_params will be finally passed to embedding lookup module. But before that: + logic tables in `EmbeddingCollection` will be divided into multiple groups in the `ShardedDynamicEmbeddingCollection`, + and the fused_params are equal for tables in the same group. + However, we only provide the common for all tables here, but some fields in `DynamicEmbTableOptions` will be merged into fused_params + and then be used to group tables(please refer DynamicEmbTableOptions for more details). + **Performance** issue: Embedding lookup within the same group can be executed in parallel, + while embedding lookup between different groups can only be executed sequentially. + use_index_dedup: + Unlike `EmbeddingBagCollection`, there is no reduction operation at the jagged dimension in the input `KeyedJaggedTensor` for `EmbeddingCollection`. + Therefore, we can deduplicate the input's indices in the input distributor before sparse feature's all-to-all, + then it will reduce the bandwidth pressure of NVLink or PCIe when perform embedding's all-to-all, and restore them using inverse information finally. + qcomm_codecs_registry: used to configure the embeddings(forward) or gradients(backward)' precision when perform all-to-all operation across different ranks + in distributed environment. + """ + return DynamicEmbeddingCollectionSharder( + qcomm_codecs_registry=qcomm_codecs_registry, + fused_params=fused_params, + use_index_dedup=True, + ) + + # use a function warp all the Planner code -def get_planner(device, eb_configs, batch_size): +def get_planner( + device, eb_configs, batch_size, optimizer_type, training, caching, args +): DATA_TYPE_NUM_BITS: Dict[DataType, int] = { DataType.FP32: 32, DataType.FP16: 16, @@ -353,12 +456,12 @@ def get_planner(device, eb_configs, batch_size): ddr_cap = 512 * 1024 * 1024 * 1024 # Assume a Node have 512GB memory intra_host_bw = 450e9 # Nvlink bandwidth inter_host_bw = 25e9 # NIC bandwidth + bucket_capacity = 1024 if caching else 128 dict_const = {} for eb_config in eb_configs: - # For HVK embedding table , need to calculate how many bytes of embedding vector store in GPU HBM - # In this case , we will put all the embedding vector into GPU HBM + # For HVK embedding table, need to calculate how many bytes of embedding vector store in GPU HBM dim = eb_config.embedding_dim tmp_type = eb_config.data_type @@ -367,19 +470,62 @@ def get_planner(device, eb_configs, batch_size): emb_num_embeddings_next_power_of_2 = 2 ** math.ceil( math.log2(emb_num_embeddings) ) # HKV need embedding vector num is power of 2 - total_hbm_need = embedding_type_bytes * dim * emb_num_embeddings_next_power_of_2 + threshold = (bucket_capacity * world_size) / cache_ratio + threshold_int = math.ceil(threshold) + if emb_num_embeddings_next_power_of_2 < threshold_int: + emb_num_embeddings_next_power_of_2 = 2 ** math.ceil( + math.log2(threshold_int) + ) + + # e.g. for adam, its `x`` embedding + `2x`` optimizer states + total_dim = dim + get_optimizer_state_dim( + convert_optimizer_type(optimizer_type), dim, data_type_to_dtype(tmp_type) + ) + total_hbm_need = ( + embedding_type_bytes * total_dim * emb_num_embeddings_next_power_of_2 + ) + + # Setup admission strategy if threshold > 0 + admit_strategy = None + admission_counter = None + if args.admission_threshold > 0: + print( + f"Admission strategy enabled with threshold={args.admission_threshold}" + ) + # Create counter to track key frequencies + admission_counter = KVCounter( + capacity=emb_num_embeddings_next_power_of_2, + bucket_capacity=bucket_capacity, + key_type=torch.int64, + device=device, + ) + + # Create admission strategy with threshold + admit_strategy = FrequencyAdmissionStrategy( + threshold=args.admission_threshold, + initializer_args=DynamicEmbInitializerArgs( + mode=DynamicEmbInitializerMode.CONSTANT, + value=0.0, # Initialize rejected keys to 0 + ), + ) const = DynamicEmbParameterConstraints( sharding_types=[ - ShardingType.ROW_WISE.value, + ShardingType.ROW_WISE.value, # dynamicemb embedding table only support to be sharded in row-wise. ], - use_dynamicemb=True, # from here , is all the HKV options , default use_dynamicemb is False , if it is False , it will fallback to raw TorchREC ParameterConstraints + use_dynamicemb=True, # indicate using dynamicemb, and will fallback to raw ParameterConstraints when Fale. dynamicemb_options=DynamicEmbTableOptions( - global_hbm_for_values=total_hbm_need, + global_hbm_for_values=total_hbm_need * cache_ratio + if caching + else total_hbm_need, initializer_args=DynamicEmbInitializerArgs( mode=DynamicEmbInitializerMode.NORMAL ), score_strategy=DynamicEmbScoreStrategy.STEP, + caching=caching, + training=training, + admit_strategy=admit_strategy, + admission_counter=admission_counter, ), ) @@ -390,18 +536,19 @@ def get_planner(device, eb_configs, batch_size): world_size=dist.get_world_size(), compute_device=device.type, hbm_cap=hbm_cap, - ddr_cap=ddr_cap, # For HVK , if we need to put embedding vector into Host memory , it is important set ddr capacity + ddr_cap=ddr_cap, intra_host_bw=intra_host_bw, inter_host_bw=inter_host_bw, ) - # Same usage of TorchREC's EmbeddingEnumerator + # same usage of torchrec's EmbeddingEnumerator enumerator = DynamicEmbeddingEnumerator( topology=topology, constraints=dict_const, ) - # Almost same usage of TorchREC's EmbeddingShardingPlanner , but we need to input eb_configs, so we can plan every GPU's HKV object. + # Almost same usage of torchrec's EmbeddingShardingPlanner, except to input eb_configs, + # as dynamicemb need EmbeddingConfig info to help to plan. return DynamicEmbeddingShardingPlanner( eb_configs=eb_configs, topology=topology, @@ -413,57 +560,57 @@ def get_planner(device, eb_configs, batch_size): ) -def apply_dmp(model, args): +def apply_dmp(model, args, training): + """ + The initialization of embedding lookup module in dynamicemb is almost consistent with torchrec. + 1. Firstly, you should configure the global parameters of an embedding table using `EmbeddingCollection`. + 2. Then, build a `DynamicEmbeddingCollectionSharder`, and generate `ShardingPlan` from `DynamicEmbeddingShardingPlanner`. + 3. Finally, pass all parameters to the `DistributedModelParallel`, which then handles the embedding sharding and initialization. + """ eb_configs = model.embedding_module.embedding_configs() - # set optimizer args - learning_rate = args.lr - beta1 = 0.9 - beta2 = 0.999 - weight_decay = 0 - eps = 0.001 - - # Put args into a optimizer kwargs , which is same usage of TorchREC - optimizer_kwargs = { - "optimizer": EmbOptimType.ADAM, - "learning_rate": learning_rate, - "beta1": beta1, - "beta2": beta2, - "weight_decay": weight_decay, - "eps": eps, - } - - fused_params = {} - fused_params["output_dtype"] = SparseType.FP32 - fused_params.update(optimizer_kwargs) - - # precision of all-to-all - qcomm_codecs_registry = ( - get_qcomm_codecs_registry( - qcomms_config=QCommsConfig( - # pyre-ignore - forward_precision=CommType.FP32, - # pyre-ignore - backward_precision=CommType.FP32, - ) - ) - if backend == "nccl" - else None + optimizer_type = EmbOptimType.ADAM + + """ + After configuring the `EmbeddingCollection`, you need to configure `DynamicEmbeddingCollectionSharder`. + It can create an instance of `ShardedDynamicEmbeddingCollection`. + `ShardedDynamicEmbeddingCollection` provides customized embedding lookup module base on + [HKV](https://github.com/NVIDIA-Merlin/HierarchicalKV), a GPU hash table which can utilize both device and host memory, + support automatic eviction based on score(per key) while provide a better performance. + Besides, due to differences in deduplication between hash tables and array based static tables, + `ShardedDynamicEmbeddingCollection` also provide customized input distributor to support deduplication when `use_index_dedup=True`. + The actual sharding operation occurs during the initialization of the `ShardedDynamicEmbeddingCollection`, + but the parameters used to initialize `DynamicEmbeddingCollectionSharder` will play a key role in the sharding process. + By the way, `DynamicEmbeddingCollectionSharder` inherits `EmbeddingCollectionSharder`, + and its main job is return an instance of `ShardedDynamicEmbeddingCollection`. + """ + sharder = get_sharder(args, optimizer_type) + + """ + The next step of preparation is to generate a `ParameterSharding` for each table, describe (configure) the sharding of a parameter. + For dynamic embedding table, `DynamicEmbParameterSharding` will be generated, which includes the parameters required from our embedding lookup module. + We will not expand `DynamicEmbParameterSharding` here. + The following steps demonstrate how to obtain `DynamicEmbParameterSharding` by `DynamicEmbeddingShardingPlanner`. + """ + planner = get_planner( + device, + eb_configs, + args.batch_size, + optimizer_type=optimizer_type, + training=training, + caching=args.caching, + args=args, ) - - # Create a sharder , same usage with TorchREC , but need Use DynamicEmb function, because for index_dedup - # DynamicEmb overload this process to fit HKV - - sharder = DynamicEmbeddingCollectionSharder( - qcomm_codecs_registry=qcomm_codecs_registry, - fused_params=fused_params, - use_index_dedup=True, + # get plan for all ranks. + # ShardingPlan is a dict, mapping table name to ParameterSharding/DynamicEmbParameterSharding. + plan: ShardingPlan = planner.collective_plan( + model, [sharder], dist.GroupMember.WORLD ) - planner = get_planner(device, eb_configs, args.batch_size) - # Same usage of TorchREC - plan = planner.collective_plan(model, [sharder], dist.GroupMember.WORLD) - - # Same usage of TorchREC + """ + The final step is to input the `sharder` and `ShardingPlan` to the `DistributedModelParallel`, + who will implement the sharded plan through `sharder` and hold the `ShardedDynamicEmbeddingCollection` after sharding. + Then you can use `dmp` for **training** and **evaluation**, just like using `EmbeddingCollection`. + """ dmp = DistributedModelParallel( module=model, device=device, @@ -474,13 +621,17 @@ def apply_dmp(model, args): return dmp -def create_model(args): +def create_model(args, training=True): + # Define the configuration parameters for the embedding table, + # including its name, embedding dimension, total number of embeddings, and feature name. eb_configs = [ EmbeddingConfig( name="user_id", embedding_dim=args.embedding_dim, - num_embeddings=args.num_embeddings, # sum for all ranks. - feature_names=["user_id"], + num_embeddings=args.num_embeddings, # `num_embeddings` in `EmbeddingConfig` is the sum of all slices on all GPUs for a table. + feature_names=[ + "user_id" + ], # a list, means different features can share the same table data_type=DataType.FP32, # weight or embedding's data type. ), EmbeddingConfig( @@ -515,6 +666,10 @@ def create_model(args): ), ] + """ + `EmbeddingCollection` is a collection of multiple logical tables. + It does not allocate memory for embedding tables(device is "meta"). + """ ec = EmbeddingCollection( tables=eb_configs, device=torch.device("meta"), # set device to 'meta @@ -529,7 +684,7 @@ def create_model(args): over_arch_layer_sizes=mlp_dims, ) - model = apply_dmp(model, args) + model = apply_dmp(model, args, training) return model @@ -563,7 +718,7 @@ def train_one_epoch(model, train_loader, optimizer, loss_fn, epoch, total_epochs def test_one_epoch(model, test_loader, loss_fn, epoch, total_epochs): model.eval() test_loss = 0 - with torch.no_grad(): + with torch.inference_mode(): for features, labels in test_loader: features = features.to(device) labels = labels.to(device) @@ -580,10 +735,10 @@ def train(args): train_dataset = MovieLensDataset(args.data_path, split="train") test_dataset = MovieLensDataset(args.data_path, split="test") train_sampler = DistributedSampler( - train_dataset, num_replicas=world_size, rank=local_rank, shuffle=True + train_dataset, num_replicas=world_size, rank=dist.get_rank(), shuffle=True ) test_sampler = DistributedSampler( - test_dataset, num_replicas=world_size, rank=local_rank, shuffle=False + test_dataset, num_replicas=world_size, rank=dist.get_rank(), shuffle=False ) train_loader = DataLoader( @@ -619,8 +774,9 @@ def train(args): def dump(args): os.makedirs(args.save_dir, exist_ok=True) train_dataset = MovieLensDataset(args.data_path, split="train") + # Use global rank for proper data distribution across all processes train_sampler = DistributedSampler( - train_dataset, num_replicas=world_size, rank=local_rank, shuffle=True + train_dataset, num_replicas=world_size, rank=dist.get_rank(), shuffle=True ) train_loader = DataLoader( @@ -648,17 +804,19 @@ def dump(args): "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), }, - os.path.join(args.save_dir, f"model_epoch_{epoch+1}_rank{local_rank}.pt"), + os.path.join( + args.save_dir, f"model_epoch_{epoch+1}_rank{dist.get_rank()}.pt" + ), ) - # rank0 will gether embedding from other ranks, so no need to identify rank info. DynamicEmbDump(os.path.join(args.save_dir, "dynamicemb"), model, optim=True) def load(args): os.makedirs(args.save_dir, exist_ok=True) test_dataset = MovieLensDataset(args.data_path, split="test") + # Use global rank for proper data distribution across all processes test_sampler = DistributedSampler( - test_dataset, num_replicas=world_size, rank=local_rank, shuffle=False + test_dataset, num_replicas=world_size, rank=dist.get_rank(), shuffle=False ) test_loader = DataLoader( @@ -678,29 +836,35 @@ def load(args): # load checkpoint = torch.load( - os.path.join(args.save_dir, f"model_epoch_{args.epochs}_rank{local_rank}.pt"), + os.path.join( + args.save_dir, f"model_epoch_{args.epochs}_rank{dist.get_rank()}.pt" + ), weights_only=True, ) # Must set strict to False, as there is no embedding's weight in model.state_dict() model.load_state_dict(checkpoint["model_state_dict"], strict=False) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - # all rank will load from the same files. DynamicEmbLoad(os.path.join(args.save_dir, "dynamicemb"), model, optim=True) test_one_epoch(model, test_loader, criterion, 0, 1) dist.barrier(device_ids=[local_rank]) - if local_rank == 0: - shutil.rmtree(args.save_dir) + # Only global rank 0 should clean up, not local rank 0 on each node + if dist.get_rank() == 0: + try: + shutil.rmtree(args.save_dir) + except Exception as e: + print(f"Warning: Failed to remove {args.save_dir}: {e}") dist.barrier(device_ids=[local_rank]) def inc_dump(args): os.makedirs(args.save_dir, exist_ok=True) train_dataset = MovieLensDataset(args.data_path, split="train") + # Use global rank for proper data distribution across all processes train_sampler = DistributedSampler( - train_dataset, num_replicas=world_size, rank=local_rank, shuffle=True + train_dataset, num_replicas=world_size, rank=dist.get_rank(), shuffle=True ) train_loader = DataLoader( @@ -762,7 +926,7 @@ def main(): args = parse_args() torch.cuda.manual_seed(args.seed) np.random.seed(args.seed) - if local_rank == 0: + if dist.get_rank() == 0: # Use global rank for multi-node consistency download_movielens(args.data_path) dist.barrier(device_ids=[local_rank]) if args.train: diff --git a/corelib/dynamicemb/setup.py b/corelib/dynamicemb/setup.py index 39e42e512..5767d400f 100644 --- a/corelib/dynamicemb/setup.py +++ b/corelib/dynamicemb/setup.py @@ -16,6 +16,7 @@ import os import re import subprocess +import sys from pathlib import Path from setuptools import find_packages, setup @@ -24,6 +25,18 @@ subprocess.run( ["git", "submodule", "update", "--init", "../../third_party/HierarchicalKV"] ) +subprocess.run( + [ + sys.executable, + "-m", + "pip", + "uninstall", + "-y", + "dynamicemb", + "--break-system-packages", + ] +) +subprocess.run(["pip", "install", "ordered-set", "--break-system-packages"]) # TODO: update when torchrec release compatible commit. compatible_versions = "1.1.0" diff --git a/corelib/dynamicemb/src/dynamic_emb_op.cu b/corelib/dynamicemb/src/dynamic_emb_op.cu index a2eca9325..407bef01d 100644 --- a/corelib/dynamicemb/src/dynamic_emb_op.cu +++ b/corelib/dynamicemb/src/dynamic_emb_op.cu @@ -28,6 +28,7 @@ #include "index_calculation.h" #include "lookup_backward.h" #include "lookup_forward.h" +#include "lookup_kernel.cuh" #include "torch_utils.h" #include "unique_op.h" #include "utils.h" @@ -158,7 +159,36 @@ void insert_and_evict( reinterpret_cast(d_evicted_counter.data_ptr()), stream, unique_key, ignore_evict_strategy); } } +void insert_and_evict_with_scores( + std::shared_ptr table, + const size_t n, + const at::Tensor keys, + const at::Tensor values, + at::Tensor evicted_keys, + at::Tensor evicted_values, + at::Tensor evicted_score, + at::Tensor d_evicted_counter, + bool unique_key = true, + bool ignore_evict_strategy = false, + const std::optional scores = std::nullopt +) { +if (not scores.has_value() and (table->evict_strategy() == EvictStrategy::kCustomized || table->evict_strategy() == EvictStrategy::kLfu)) { + throw std::invalid_argument("Must specify the score when evict strategy is customized or LFU."); +} +auto stream = at::cuda::getCurrentCUDAStream().stream(); +if (table->evict_strategy() == EvictStrategy::kCustomized || table->evict_strategy() == EvictStrategy::kLfu) { + table->insert_and_evict( + n, keys.data_ptr(), values.data_ptr(), scores.value().data_ptr(), + evicted_keys.data_ptr(), evicted_values.data_ptr(), evicted_score.data_ptr(), + reinterpret_cast(d_evicted_counter.data_ptr()), stream, unique_key, ignore_evict_strategy); +} else { + table->insert_and_evict( + n, keys.data_ptr(), values.data_ptr(), nullptr, + evicted_keys.data_ptr(), evicted_values.data_ptr(), evicted_score.data_ptr(), + reinterpret_cast(d_evicted_counter.data_ptr()), stream, unique_key, ignore_evict_strategy); +} +} void accum_or_assign(std::shared_ptr table, const size_t n, const at::Tensor keys, const at::Tensor value_or_deltas, @@ -179,6 +209,26 @@ void accum_or_assign(std::shared_ptr table, } } + +void find_and_initialize( + std::shared_ptr table, + const size_t n, + const at::Tensor keys, + const at::Tensor values, + std::optional initializer_args) { + + if (n == 0) return; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + at::Tensor vals_ptr_tensor = at::empty({static_cast(n)}, + at::TensorOptions().dtype(at::kLong).device(values.device())); + auto vals_ptr = reinterpret_cast(vals_ptr_tensor.data_ptr()); + at::Tensor founds_tensor = at::empty({static_cast(n)}, + at::TensorOptions().dtype(at::kBool).device(keys.device())); + auto founds = founds_tensor.data_ptr(); + + table->find_and_initialize(n, keys.data_ptr(), vals_ptr, values.data_ptr(), founds, initializer_args, stream); +} + void find_or_insert(std::shared_ptr table, const size_t n, const at::Tensor keys, @@ -253,7 +303,9 @@ void find_pointers( const size_t n, const at::Tensor keys, at::Tensor values, - at::Tensor founds) { + at::Tensor founds, + const std::optional score = std::nullopt +) { if (n == 0) return; auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -262,6 +314,53 @@ void find_pointers( table->find_pointers(n, keys.data_ptr(), values_data_ptr, found_tensor_data_ptr, nullptr, stream); + + // update score. + if (score.has_value()) { + at::Tensor locked_ptr = at::empty({static_cast(n)}, keys.options().dtype(at::kLong)); + at::Tensor success = at::empty({static_cast(n)}, keys.options().dtype(at::kBool)); + if (table->evict_strategy() == EvictStrategy::kCustomized || table->evict_strategy() == EvictStrategy::kLfu) { + auto&& option = at::TensorOptions().dtype(at::kUInt64).device(keys.device()); + // broadcast scores + at::Tensor bc_scores = at::empty({static_cast(n)}, option); + bc_scores.fill_(score.value()); + table->lock(n, keys.data_ptr(), reinterpret_cast(locked_ptr.data_ptr()), + success.data_ptr(), bc_scores.data_ptr(), stream); + } else { + table->lock(n, keys.data_ptr(), reinterpret_cast(locked_ptr.data_ptr()), + success.data_ptr(), nullptr, stream); + } + AT_CUDA_CHECK(cudaGetLastError()); + table->unlock(n, reinterpret_cast(locked_ptr.data_ptr()), keys.data_ptr(), success.data_ptr(), stream); + } +} + +void find_pointers_with_scores( + std::shared_ptr table, + const size_t n, + const at::Tensor keys, + at::Tensor values, + at::Tensor founds, + const std::optional &scores = std::nullopt +) { + + if (n == 0) return; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto values_data_ptr = reinterpret_cast(values.data_ptr()); + auto found_tensor_data_ptr = founds.data_ptr(); + + // update score. + if (scores.has_value()) { + if (table->evict_strategy() == EvictStrategy::kCustomized || table->evict_strategy() == EvictStrategy::kLfu) { + table->find_pointers(n, keys.data_ptr(), values_data_ptr, found_tensor_data_ptr, scores.value().data_ptr(), stream); + } + else { + table->find_pointers(n, keys.data_ptr(), values_data_ptr, found_tensor_data_ptr, nullptr, stream); + } + } else { + std::shared_ptr const_table = table; + const_table->find_pointers(n, keys.data_ptr(), values_data_ptr, found_tensor_data_ptr, nullptr, stream); + } } void assign(std::shared_ptr table, const size_t n, @@ -354,6 +453,45 @@ void export_batch_matched( keys.data_ptr(), values.data_ptr(), nullptr, stream); } +template +__global__ void compact_offsets( + const scalar_t *offsets, + scalar_t *features_offsets, + const int64_t num_features, + const int64_t batch_size +) { + for (int tid = threadIdx.x + blockIdx.x * blockDim.x; tid < num_features; tid += blockDim.x * gridDim.x) { + features_offsets[tid] = offsets[tid * batch_size]; + } + if (threadIdx.x == 0) { + features_offsets[num_features] = offsets[num_features * batch_size]; + } +} + +std::vector offsets_to_table_features_offsets(const at::Tensor &offsets, const std::vector &table_offsets_in_feature, const int64_t batch_size, cudaStream_t stream) { + int64_t table_num = table_offsets_in_feature.size() - 1; + int64_t num_features = (offsets.numel() - 1) / batch_size; + at::Tensor h_features_offsets = + at::empty({num_features + 1}, offsets.options().device(at::kCPU).pinned_memory(true)); + if (num_features == 0) { + return {0, 0}; + } + AT_DISPATCH_INTEGRAL_TYPES(offsets.scalar_type(), "compact_offsets", [&] { + compact_offsets<<>>( + offsets.data_ptr(), + h_features_offsets.data_ptr(), + num_features, + batch_size + ); + }); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + std::vector table_features_offsets(table_offsets_in_feature.size(), 0); + for (int i = 0; i < table_offsets_in_feature.size(); ++i) { + table_features_offsets[i] = h_features_offsets[table_offsets_in_feature[i]].item(); + } + return table_features_offsets; +} + void lookup_forward_dense( std::vector> tables, const at::Tensor indices, const at::Tensor offsets, const py::list scores, @@ -378,18 +516,14 @@ void lookup_forward_dense( } auto stream = at::cuda::getCurrentCUDAStream().stream(); - int64_t indices_shape = indices.size(0); + int64_t indices_shape = indices.numel(); auto unique_num_type = scalartype_to_datatype( convertTypeMetaToScalarType(d_unique_nums.dtype())); auto unique_offset_type = scalartype_to_datatype( convertTypeMetaToScalarType(d_unique_offsets.dtype())); - at::Tensor h_offset = - at::empty_like(offsets, offsets.options().device(at::kCPU)); - AT_CUDA_CHECK(cudaMemcpyAsync(h_offset.data_ptr(), offsets.data_ptr(), - offsets.numel() * offsets.element_size(), - cudaMemcpyDeviceToHost, stream)); - + auto h_table_offsets = offsets_to_table_features_offsets(offsets, table_offsets_in_feature, batch_size, stream); + size_t unique_op_capacity = unique_op->get_capacity(); if (indices_shape * 2 > unique_op_capacity) { at::Tensor new_keys = at::empty({indices_shape * 2}, indices.options()); @@ -404,21 +538,10 @@ void lookup_forward_dense( tmp_unique_indices[i] = at::empty_like(indices); } - at::Tensor h_table_offsets = - at::empty({table_num + 1}, table_offsets.options().device(at::kCPU)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - - h_table_offsets[0] = 0; for (int i = 0; i < table_num; ++i) { - int table_offset_begin = table_offsets_in_feature[i]; - int table_offset_end = table_offsets_in_feature[i + 1]; - int offset_begin = table_offset_begin * batch_size; - int offset_end = table_offset_end * batch_size; - - int64_t indices_begin = h_offset[offset_begin].item(); - int64_t indices_end = h_offset[offset_end].item(); + int64_t indices_begin = h_table_offsets[i]; + int64_t indices_end = h_table_offsets[i + 1]; int64_t indices_length = indices_end - indices_begin; - h_table_offsets[i + 1] = indices_end; if (indices_length == 0) { DEMB_CUDA_CHECK(cudaMemsetAsync( @@ -451,7 +574,7 @@ void lookup_forward_dense( cudaMemcpyDeviceToHost, stream)); AT_CUDA_CHECK(cudaStreamSynchronize(stream)); AT_CUDA_CHECK( - cudaMemcpyAsync(table_offsets.data_ptr(), h_table_offsets.data_ptr(), + cudaMemcpyAsync(table_offsets.data_ptr(), h_table_offsets.data(), table_offsets.numel() * table_offsets.element_size(), cudaMemcpyHostToDevice, stream)); @@ -463,7 +586,7 @@ void lookup_forward_dense( create_sub_tensor(unique_embs, unique_embs_offset * dim); auto score = std::make_optional(py::cast(scores[i])); find_or_insert(tables[i], tmp_unique_num, tmp_unique_indices[i], - tmp_unique_embs, score); + tmp_unique_embs, score); if (use_index_dedup) { void *dst_ptr = reinterpret_cast(unique_idx.data_ptr()) + unique_embs_offset * unique_idx.element_size(); @@ -487,54 +610,41 @@ void lookup_forward_dense( dst_type, offset_type, device_num_sms, stream); } -void lookup_forward_dense( +at::Tensor lookup_forward_dense_eval( std::vector> tables, - const at::Tensor indices, const at::Tensor offsets, - const std::vector &table_offsets_in_feature, int table_num, - int batch_size, int dim, const at::Tensor h_unique_offsets, - const at::Tensor unique_embs, const at::Tensor output_embs) { - - if (!offsets.is_cuda() || !indices.is_cuda()) { + const at::Tensor &indices, + const at::Tensor &offsets, + const std::vector &table_offsets_in_feature, + at::ScalarType embedding_dtype, + int table_num, + int batch_size, + int dim, + const at::Device& device, + const std::vector &eval_initializers) { + + if (!indices.is_cuda() || !offsets.is_cuda()) { throw std::runtime_error( "offsets or indices tensor must be on CUDA device"); } auto stream = at::cuda::getCurrentCUDAStream().stream(); - int64_t indices_shape = indices.size(0); - auto scalar_type = unique_embs.dtype().toScalarType(); - auto emb_dtype = scalartype_to_datatype(scalar_type); - scalar_type = output_embs.dtype().toScalarType(); - auto output_dtype = scalartype_to_datatype(scalar_type); - auto &device_prop = DeviceProp::getDeviceProp(indices.device().index()); - - at::Tensor h_offset = - at::empty_like(offsets, offsets.options().device(at::kCPU)); - AT_CUDA_CHECK(cudaMemcpyAsync(h_offset.data_ptr(), offsets.data_ptr(), - offsets.numel() * offsets.element_size(), - cudaMemcpyDeviceToHost, stream)); - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + int64_t num_indices = indices.numel(); - h_unique_offsets[0] = 0; - for (int i = 0; i < table_num; ++i) { - int table_offset_begin = table_offsets_in_feature[i]; - int table_offset_end = table_offsets_in_feature[i + 1]; - int offset_begin = table_offset_begin * batch_size; - int offset_end = table_offset_end * batch_size; + at::Tensor output_embs = at::empty({num_indices, dim}, at::TensorOptions().dtype(embedding_dtype).device(device)); - int64_t indices_begin = h_offset[offset_begin].item(); - int64_t indices_end = h_offset[offset_end].item(); - int64_t indices_length = indices_end - indices_begin; - h_unique_offsets[i + 1] = indices_end; - at::Tensor tmp_indices = create_sub_tensor(indices, indices_begin); - at::Tensor tmp_unique_embs = - create_sub_tensor(unique_embs, indices_begin * dim); - find_or_insert(tables[i], indices_length, tmp_indices, tmp_unique_embs); - at::Tensor tmp_output_embs = - create_sub_tensor(output_embs, indices_begin * dim); - dyn_emb::batched_vector_copy_device( - tmp_unique_embs.data_ptr(), output_embs.data_ptr(), indices_length, dim, - emb_dtype, output_dtype, device_prop.num_sms, stream); + auto table_features_offsets = offsets_to_table_features_offsets(offsets, table_offsets_in_feature, batch_size, stream); + + for (int i = 0; i < table_num; ++i) { + int64_t table_offset_begin = table_features_offsets[i]; + int64_t table_offset_end = table_features_offsets[i + 1]; + int64_t table_offset_length = table_offset_end - table_offset_begin; + at::Tensor current_indices = create_sub_tensor(indices, table_offset_begin); + at::Tensor current_output_embs = create_sub_tensor(output_embs, table_offset_begin * dim); + + find_and_initialize(tables[i], static_cast(table_offset_length), current_indices, current_output_embs, eval_initializers[i]); } + + return output_embs; } void lookup_backward_dense(const at::Tensor indices, const at::Tensor grads, @@ -588,6 +698,18 @@ void lookup_backward_dense(const at::Tensor indices, const at::Tensor grads, unique_key_ids, stream); } +std::tuple +reduce_grads(at::Tensor indices, at::Tensor grads, at::Tensor segment_range, at::Tensor h_segment_range) { + int64_t num_total = indices.size(0); + int64_t dim = grads.size(1); + int64_t num_segment = h_segment_range.size(0) - 1; + int64_t num_unique_total = h_segment_range[num_segment].item(); + at::Tensor unique_indices = at::empty(num_unique_total, indices.options()); + at::Tensor unique_grads = at::empty({num_unique_total, dim}, grads.options()); + lookup_backward_dense(indices, grads, dim, segment_range, unique_indices, unique_grads); + return std::make_tuple(unique_indices, unique_grads); +} + void lookup_backward_dense_dedup(const at::Tensor grads, at::Tensor unique_indices, at::Tensor reverse_idx, int32_t dim, @@ -616,110 +738,117 @@ void lookup_backward_dense_dedup(const at::Tensor grads, } void dedup_input_indices( - const at::Tensor indices, const at::Tensor offsets, - const at::Tensor h_table_offsets_in_feature, - const at::Tensor d_table_offsets_in_feature, int table_num, - int local_batch_size, const at::Tensor reverse_idx, - const at::Tensor h_unique_nums, const at::Tensor d_unique_nums, - const at::Tensor h_unique_offsets, const at::Tensor d_unique_offsets, - std::vector unique_idx, const at::Tensor new_offsets, - const at::Tensor new_lengths, int device_num_sms, - std::shared_ptr unique_op) { - - if (!offsets.is_cuda() || !indices.is_cuda()) { - throw std::runtime_error( - "offsets or indices tensor must be on CUDA device"); - } - - // Check dtype of h_unique_nums and d_unique_nums - if (h_unique_nums.scalar_type() != at::kUInt64 || - d_unique_nums.scalar_type() != at::kUInt64) { - throw std::runtime_error( - "h_unique_nums and d_unique_nums must have dtype uint64_t"); - } - - auto stream = at::cuda::getCurrentCUDAStream().stream(); - int64_t indices_shape = indices.size(0); - auto unique_num_type = scalartype_to_datatype( - convertTypeMetaToScalarType(d_unique_nums.dtype())); - auto unique_offset_type = scalartype_to_datatype( - convertTypeMetaToScalarType(d_unique_offsets.dtype())); - int64_t new_lengths_size = new_lengths.size(0); + const at::Tensor indices, const at::Tensor offsets, + const at::Tensor h_table_offsets_in_feature, + const at::Tensor d_table_offsets_in_feature, int table_num, + int local_batch_size, const at::Tensor reverse_idx, + const at::Tensor h_unique_nums, const at::Tensor d_unique_nums, + const at::Tensor h_unique_offsets, const at::Tensor d_unique_offsets, + std::vector unique_idx, const at::Tensor new_offsets, + const at::Tensor new_lengths, int device_num_sms, + std::shared_ptr unique_op, + const c10::optional< std::vector > &frequency_counters = c10::nullopt, + const c10::optional &input_frequencies = c10::nullopt + ) { + - at::Tensor h_offset = - at::empty_like(offsets, offsets.options().device(at::kCPU)); - AT_CUDA_CHECK(cudaMemcpyAsync(h_offset.data_ptr(), offsets.data_ptr(), - offsets.numel() * offsets.element_size(), - cudaMemcpyDeviceToHost, stream)); +if (!offsets.is_cuda() || !indices.is_cuda()) { + throw std::runtime_error( + "offsets or indices tensor must be on CUDA device"); +} - size_t unique_op_capacity = unique_op->get_capacity(); - if (indices_shape * 2 > unique_op_capacity) { - at::Tensor new_keys = at::empty({indices_shape * 2}, indices.options()); - at::Tensor new_vals = at::empty( - {indices_shape * 2}, - at::TensorOptions().dtype(at::kUInt64).device(indices.device())); - unique_op->reset_capacity(new_keys, new_vals, indices_shape * 2, stream); - } +// Check dtype of h_unique_nums and d_unique_nums +if (h_unique_nums.scalar_type() != at::kUInt64 || + d_unique_nums.scalar_type() != at::kUInt64) { + throw std::runtime_error( + "h_unique_nums and d_unique_nums must have dtype uint64_t"); +} - std::vector tmp_unique_indices(table_num); - for (int i = 0; i < table_num; ++i) { - tmp_unique_indices[i] = at::empty_like(indices); - } +auto stream = at::cuda::getCurrentCUDAStream().stream(); +int64_t indices_shape = indices.size(0); +auto unique_num_type = scalartype_to_datatype( + convertTypeMetaToScalarType(d_unique_nums.dtype())); +auto unique_offset_type = scalartype_to_datatype( + convertTypeMetaToScalarType(d_unique_offsets.dtype())); +int64_t new_lengths_size = new_lengths.size(0); + +at::Tensor h_offset = + at::empty_like(offsets, offsets.options().device(at::kCPU)); +AT_CUDA_CHECK(cudaMemcpyAsync(h_offset.data_ptr(), offsets.data_ptr(), + offsets.numel() * offsets.element_size(), + cudaMemcpyDeviceToHost, stream)); + +size_t unique_op_capacity = unique_op->get_capacity(); +if (indices_shape * 2 > unique_op_capacity) { + at::Tensor new_keys = at::empty({indices_shape * 2}, indices.options()); + at::Tensor new_vals = at::empty( + {indices_shape * 2}, + at::TensorOptions().dtype(at::kUInt64).device(indices.device())); + unique_op->reset_capacity(new_keys, new_vals, indices_shape * 2, stream); +} - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); +std::vector tmp_unique_indices(table_num); +for (int i = 0; i < table_num; ++i) { + tmp_unique_indices[i] = at::empty_like(indices); +} - for (int i = 0; i < table_num; ++i) { - int table_offset_begin = h_table_offsets_in_feature[i].item(); - int table_offset_end = h_table_offsets_in_feature[i + 1].item(); - int offset_begin = table_offset_begin * local_batch_size; - int offset_end = table_offset_end * local_batch_size; +AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - int64_t indices_begin = h_offset[offset_begin].item(); - int64_t indices_end = h_offset[offset_end].item(); - int64_t indices_length = indices_end - indices_begin; +for (int i = 0; i < table_num; ++i) { + int table_offset_begin = h_table_offsets_in_feature[i].item(); + int table_offset_end = h_table_offsets_in_feature[i + 1].item(); + int offset_begin = table_offset_begin * local_batch_size; + int offset_end = table_offset_end * local_batch_size; - if (indices_length == 0) { - DEMB_CUDA_CHECK(cudaMemsetAsync( - reinterpret_cast(d_unique_nums.data_ptr()) + i, 0, - sizeof(uint64_t), stream)); - dyn_emb::add_offset(d_unique_nums.data_ptr(), d_unique_offsets.data_ptr(), - i, unique_num_type, unique_offset_type, stream); - } else { - at::Tensor tmp_indices = create_sub_tensor(indices, indices_begin); - at::Tensor tmp_reverse_idx = - create_sub_tensor(reverse_idx, indices_begin); - at::Tensor tmp_d_unique_num = create_sub_tensor(d_unique_nums, i); - at::Tensor previous_d_unique_num = create_sub_tensor(d_unique_offsets, i); + int64_t indices_begin = h_offset[offset_begin].item(); + int64_t indices_end = h_offset[offset_end].item(); + int64_t indices_length = indices_end - indices_begin; - unique_op->unique(tmp_indices, indices_length, tmp_reverse_idx, - unique_idx[i], tmp_d_unique_num, stream, - previous_d_unique_num); - dyn_emb::add_offset(d_unique_nums.data_ptr(), d_unique_offsets.data_ptr(), - i, unique_num_type, unique_offset_type, stream); - } + if (indices_length == 0) { + DEMB_CUDA_CHECK(cudaMemsetAsync( + reinterpret_cast(d_unique_nums.data_ptr()) + i, 0, + sizeof(uint64_t), stream)); + dyn_emb::add_offset(d_unique_nums.data_ptr(), d_unique_offsets.data_ptr(), + i, unique_num_type, unique_offset_type, stream); + } else { + at::Tensor tmp_indices = create_sub_tensor(indices, indices_begin); + at::Tensor tmp_reverse_idx = + create_sub_tensor(reverse_idx, indices_begin); + at::Tensor tmp_d_unique_num = create_sub_tensor(d_unique_nums, i); + at::Tensor previous_d_unique_num = create_sub_tensor(d_unique_offsets, i); + + // For first stage deduplication, we don't have input frequencies (set to empty tensor) + // The unique operation will default each key's frequency to 1 + at::Tensor freq_counter = frequency_counters.has_value() ? frequency_counters.value()[i] : at::Tensor(); + unique_op->unique(tmp_indices, indices_length, tmp_reverse_idx, + unique_idx[i], tmp_d_unique_num, stream, + previous_d_unique_num, freq_counter); + dyn_emb::add_offset(d_unique_nums.data_ptr(), d_unique_offsets.data_ptr(), + i, unique_num_type, unique_offset_type, stream); } +} - AT_CUDA_CHECK( - cudaMemcpyAsync(h_unique_nums.data_ptr(), d_unique_nums.data_ptr(), - d_unique_nums.numel() * d_unique_nums.element_size(), - cudaMemcpyDeviceToHost, stream)); - AT_CUDA_CHECK(cudaMemcpyAsync( - h_unique_offsets.data_ptr(), d_unique_offsets.data_ptr(), - d_unique_offsets.numel() * d_unique_offsets.element_size(), - cudaMemcpyDeviceToHost, stream)); - - AT_CUDA_CHECK(cudaStreamSynchronize(stream)); - - auto offset_type = - scalartype_to_datatype(convertTypeMetaToScalarType(new_offsets.dtype())); - auto lengths_type = - scalartype_to_datatype(convertTypeMetaToScalarType(new_lengths.dtype())); - - get_new_length_and_offsets( - reinterpret_cast(d_unique_offsets.data_ptr()), - d_table_offsets_in_feature.data_ptr(), table_num, - new_lengths_size, local_batch_size, lengths_type, offset_type, - new_offsets.data_ptr(), new_lengths.data_ptr(), stream); +AT_CUDA_CHECK( + cudaMemcpyAsync(h_unique_nums.data_ptr(), d_unique_nums.data_ptr(), + d_unique_nums.numel() * d_unique_nums.element_size(), + cudaMemcpyDeviceToHost, stream)); +AT_CUDA_CHECK(cudaMemcpyAsync( + h_unique_offsets.data_ptr(), d_unique_offsets.data_ptr(), + d_unique_offsets.numel() * d_unique_offsets.element_size(), + cudaMemcpyDeviceToHost, stream)); + +AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + +auto offset_type = + scalartype_to_datatype(convertTypeMetaToScalarType(new_offsets.dtype())); +auto lengths_type = + scalartype_to_datatype(convertTypeMetaToScalarType(new_lengths.dtype())); + +get_new_length_and_offsets( + reinterpret_cast(d_unique_offsets.data_ptr()), + d_table_offsets_in_feature.data_ptr(), table_num, + new_lengths_size, local_batch_size, lengths_type, offset_type, + new_offsets.data_ptr(), new_lengths.data_ptr(), stream); } void lookup_forward(const at::Tensor src, const at::Tensor dst, @@ -752,7 +881,7 @@ void lookup_backward(const at::Tensor grad, const at::Tensor unique_buffer, const at::Tensor inverse_indices, const at::Tensor biased_offsets, const int dim, const int table_num, int batch_size, int feature_num, - int num_key) { + int num_key, int combiner) { auto stream = at::cuda::getCurrentCUDAStream().stream(); auto value_type = scalartype_to_datatype( @@ -762,7 +891,98 @@ void lookup_backward(const at::Tensor grad, const at::Tensor unique_buffer, dyn_emb::backward(grad.data_ptr(), unique_buffer.data_ptr(), unique_indices.data_ptr(), inverse_indices.data_ptr(), biased_offsets.data_ptr(), dim, batch_size, feature_num, - num_key, key_type, value_type, stream); + num_key, combiner, key_type, value_type, stream); +} + +template +__global__ void load_from_pointers_kernel_vec4( + int batch, + int emb_dim, + T* __restrict__ outputs, + T* const * __restrict__ src_ptrs) { + + constexpr int kWarpSize = 32; + constexpr int VecSize = 4; + const int warp_num_per_block = blockDim.x / kWarpSize; + const int warp_id_in_block = threadIdx.x / kWarpSize; + const int lane_id = threadIdx.x % kWarpSize; + + Vec4T emb; + for (int emb_id = warp_num_per_block * blockIdx.x + warp_id_in_block; + emb_id < batch; emb_id += gridDim.x * warp_num_per_block) { + T* const src_ptr = src_ptrs[emb_id]; + T* dst_ptr = outputs + emb_id * emb_dim; + if (src_ptr != nullptr) { + for (int i = 0; VecSize * (kWarpSize * i + lane_id) < emb_dim; ++i) { + int idx4 = VecSize * (kWarpSize * i + lane_id); + emb.load(src_ptr + idx4); + emb.store(dst_ptr + idx4); + } + } + } +} + +template +__global__ void load_from_pointers_kernel( + int batch, + int emb_dim, + T* __restrict__ outputs, + T* const * __restrict__ src_ptrs) { + + for (int emb_id = blockIdx.x; emb_id < batch; emb_id += gridDim.x) { + T* const src_ptr = src_ptrs[emb_id]; + T* dst_ptr = outputs + emb_id * emb_dim; + if (src_ptr != nullptr) { + for (int i = threadIdx.x; i < emb_dim; i += blockDim.x) { + dst_ptr[i] = src_ptr[i]; + } + } + } +} + +void load_from_pointers(at::Tensor pointers, at::Tensor dst) { + int64_t num_total = pointers.size(0); + int64_t dim = dst.size(1); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + constexpr int kWarpSize = 32; + constexpr int MULTIPLIER = 4; + constexpr int BLOCK_SIZE_VEC = 64; + constexpr int WARP_PER_BLOCK = BLOCK_SIZE_VEC / kWarpSize; + auto &device_prop = DeviceProp::getDeviceProp(); + const int max_grid_size = + device_prop.num_sms * + (device_prop.max_thread_per_sm / BLOCK_SIZE_VEC); + + int grid_size = 0; + if (num_total / WARP_PER_BLOCK < max_grid_size) { + grid_size = (num_total - 1) / WARP_PER_BLOCK + 1; + } else if (num_total / WARP_PER_BLOCK > max_grid_size * MULTIPLIER) { + grid_size = max_grid_size * MULTIPLIER; + } else { + grid_size = max_grid_size; + } + + auto scalar_type = dst.dtype().toScalarType(); + auto value_type = scalartype_to_datatype(scalar_type); + DISPATCH_FLOAT_DATATYPE_FUNCTION(value_type, ValueType, [&] { + if (dim % 4 == 0) { + load_from_pointers_kernel_vec4 + <<>>( + num_total, dim, reinterpret_cast(dst.data_ptr()), + reinterpret_cast(pointers.data_ptr())); + } else { + int block_size = dim < device_prop.max_thread_per_block + ? dim + : device_prop.max_thread_per_block; + int grid_size = num_total; + load_from_pointers_kernel + <<>>( + num_total, dim, reinterpret_cast(dst.data_ptr()), + reinterpret_cast(pointers.data_ptr())); + } + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); } // PYTHON WARP @@ -852,12 +1072,23 @@ void bind_dyn_emb_op(py::module &m) { py::arg("evicted_keys"), py::arg("evicted_values"), py::arg("evicted_score"), py::arg("d_evicted_counter"), py::arg("unique_key") = true, py::arg("ignore_evict_strategy") = false); - + m.def("insert_and_evict_with_scores", &insert_and_evict_with_scores, + "Insert keys and values, evicting if necessary", py::arg("table"), + py::arg("n"), py::arg("keys"), py::arg("values"), + py::arg("evicted_keys"), py::arg("evicted_values"), + py::arg("evicted_score"), py::arg("d_evicted_counter"), + py::arg("unique_key") = true, py::arg("ignore_evict_strategy") = false, + py::arg("scores") = py::none()); m.def("accum_or_assign", &accum_or_assign, "Accumulate or assign values to the table", py::arg("table"), py::arg("n"), py::arg("keys"), py::arg("value_or_deltas"), py::arg("accum_or_assigns"), py::arg("score") = c10::nullopt, py::arg("ignore_evict_strategy") = false); + + m.def("find_and_initialize", &find_and_initialize, + "Find and initialize a key-value pair in the table", py::arg("table"), + py::arg("n"), py::arg("keys"), py::arg("values"), + py::arg("initializer_args") = py::none()); m.def("find_or_insert", &find_or_insert, "Find or insert a key-value pair in the table", py::arg("table"), @@ -872,6 +1103,17 @@ void bind_dyn_emb_op(py::module &m) { py::arg("score") = py::none(), py::arg("unique_key") = true, py::arg("ignore_evict_strategy") = false); + m.def("find_pointers", &find_pointers, + "Find a key-value pair in the table , and return every " + "value's ptr", + py::arg("table"), py::arg("n"), py::arg("keys"), py::arg("values"), py::arg("founds"), + py::arg("score") = py::none()); + + m.def("find_pointers_with_scores", &find_pointers_with_scores, + "Find a key-value pair in the table , and return every " + "value's ptr", + py::arg("table"), py::arg("n"), py::arg("keys"), py::arg("values"), py::arg("founds"), + py::arg("scores") = py::none()); m.def("assign", &assign, "Assign values to the table based on keys", py::arg("table"), py::arg("n"), py::arg("keys"), py::arg("values"), py::arg("score") = c10::nullopt, py::arg("unique_key") = true); @@ -958,60 +1200,27 @@ void bind_dyn_emb_op(py::module &m) { py::arg("unique_buffer"), py::arg("unique_indices"), py::arg("inverse_indices"), py::arg("biased_offsets"), py::arg("dim"), py::arg("tables_num"), py::arg("batch_size"), py::arg("num_feature"), - py::arg("num_key")); - - m.def("lookup_forward_dense", - (void (*)(std::vector>, - const at::Tensor, const at::Tensor, const py::list, - const std::vector &, - at::Tensor, int, int, int, bool, const at::Tensor, - const at::Tensor, const at::Tensor, const at::Tensor, - const at::Tensor, const at::Tensor, const at::Tensor, - const at::Tensor, int, - std::shared_ptr)) & - lookup_forward_dense, - "lookup forward dense for duplicated keys", py::arg("tables"), - py::arg("indices"), py::arg("offsets"), py::arg("scores"), - py::arg("table_offsets_in_feature"), py::arg("table_offsets"), - py::arg("table_num"), py::arg("batch_size"), py::arg("dim"), - py::arg("use_index_dedup"), py::arg("unique_idx"), - py::arg("reverse_idx"), py::arg("h_unique_nums"), - py::arg("d_unique_nums"), py::arg("h_unique_offsets"), - py::arg("d_unique_offsets"), py::arg("unique_embs"), - py::arg("output_embs"), py::arg("device_num_sms"), - py::arg("unique_op")); - - m.def("lookup_forward_dense", - (void (*)(std::vector>, - const at::Tensor, const at::Tensor, const std::vector &, - int, int, int, const at::Tensor, const at::Tensor, - const at::Tensor)) & - lookup_forward_dense, - "lookup forward dense for globally deduplicated keys", - py::arg("tables"), py::arg("indices"), py::arg("offsets"), - py::arg("table_offsets_in_feature"), py::arg("table_num"), - py::arg("batch_size"), py::arg("dim"), py::arg("h_unique_offsets"), - py::arg("unique_embs"), py::arg("output_embs")); - - m.def("lookup_backward_dense", &lookup_backward_dense, - "lookup backward for dense/sequence", py::arg("indices"), - py::arg("grads"), py::arg("dim"), py::arg("table_offsets"), - py::arg("unique_indices"), py::arg("unique_grads")); - - - m.def("lookup_backward_dense_dedup", &lookup_backward_dense_dedup, - "lookup backward for dedup dense/sequence", py::arg("grads"), - py::arg("unique_indices"), py::arg("reverse_idx"), py::arg("dim"), - py::arg("unique_grads"), py::arg("device_num_sms")); + py::arg("num_key"), py::arg("combiner")); + m.def("dedup_input_indices", &dedup_input_indices, - "duplicate indices from a given list or array of indices", - py::arg("indices"), py::arg("offset"), - py::arg("h_table_offsets_in_feature"), - py::arg("d_table_offsets_in_feature"), py::arg("table_num"), - py::arg("local_batch_size"), py::arg("reverse_idx"), - py::arg("h_unique_nums"), py::arg("d_unique_nums"), - py::arg("h_unique_offsets"), py::arg("d_unique_offsets"), - py::arg("unique_idx"), py::arg("new_offsets"), py::arg("new_lengths"), - py::arg("device_num_sms"), py::arg("unique_op")); + "duplicate indices from a given list or array of indices", + py::arg("indices"), py::arg("offset"), + py::arg("h_table_offsets_in_feature"), + py::arg("d_table_offsets_in_feature"), py::arg("table_num"), + py::arg("local_batch_size"), py::arg("reverse_idx"), + py::arg("h_unique_nums"), py::arg("d_unique_nums"), + py::arg("h_unique_offsets"), py::arg("d_unique_offsets"), + py::arg("unique_idx"), py::arg("new_offsets"), py::arg("new_lengths"), + py::arg("device_num_sms"), py::arg("unique_op"), + py::arg("frequency_counters") = c10::nullopt, + py::arg("input_frequencies") = c10::nullopt); + + m.def("reduce_grads", &reduce_grads, + "reduce grads", py::arg("indices"), py::arg("grads"), py::arg("segment_range"), py::arg("h_segment_range") + ); + + m.def("load_from_pointers", &load_from_pointers, + "load from pointers to dst.", py::arg("pointers"), py::arg("dst") + ); } diff --git a/corelib/dynamicemb/src/dynamic_variable_base.h b/corelib/dynamicemb/src/dynamic_variable_base.h index 600785b5e..062995a39 100644 --- a/corelib/dynamicemb/src/dynamic_variable_base.h +++ b/corelib/dynamicemb/src/dynamic_variable_base.h @@ -23,6 +23,7 @@ #include #include #include +#include namespace dyn_emb { @@ -131,6 +132,10 @@ class DynamicVariableBase { bool unique_key = true, bool ignore_evict_strategy = false) = 0; + virtual void find_and_initialize( + const size_t n, const void *keys, void **value_ptrs, void *values, + bool *founds, std::optional initializer_args, const cudaStream_t& stream) = 0; + virtual void assign(const size_t n, const void *keys, // (n) const void *values, // (n, DIM) @@ -179,6 +184,20 @@ class DynamicVariableBase { uint64_t threshold, uint64_t* d_counter, cudaStream_t stream = 0) const = 0; + + virtual void lock(const size_t n, + const void* keys, // (n) + void** locked_keys_ptr, // (n) + bool* flags = nullptr, // (n) + void* scores = nullptr, // (n) + cudaStream_t stream = 0) = 0; + + virtual void unlock(const size_t n, + void** locked_keys_ptr, // (n) + const void* keys, // (n) + bool* flags = nullptr, // (n) + cudaStream_t stream = 0) = 0; + virtual curandState* get_curand_states() const = 0; virtual const InitializerArgs& get_initializer_args() const = 0; virtual const int optstate_dim() const = 0; diff --git a/corelib/dynamicemb/src/hkv_variable.cuh b/corelib/dynamicemb/src/hkv_variable.cuh index 972afd7a2..9d97160c2 100644 --- a/corelib/dynamicemb/src/hkv_variable.cuh +++ b/corelib/dynamicemb/src/hkv_variable.cuh @@ -28,6 +28,7 @@ #include #include #include "lookup_kernel.cuh" +#include "initializer.cuh" namespace { @@ -118,161 +119,6 @@ __global__ static void setup_kernel(unsigned long long seed, curand_init(seed, grid.thread_rank(), 0, &states[grid.thread_rank()]); } -struct UniformEmbeddingGenerator { - struct Args { - curandState* state; - float lower; - float upper; - }; - - DEVICE_INLINE UniformEmbeddingGenerator(Args args): load_(false), state_(args.state), - lower(args.lower), upper(args.upper) {} - - DEVICE_INLINE float generate(int64_t vec_id) { - if (!load_) { - localState_ = state_[GlobalThreadId()]; - load_ = true; - } - auto tmp = curand_uniform_double(&this->localState_); - return static_cast((upper - lower) * tmp + lower); - } - - DEVICE_INLINE void destroy() { - if (load_) { - state_[GlobalThreadId()] = localState_; - } - } - - bool load_; - curandState localState_; - curandState* state_; - float lower; - float upper; -}; - -struct NormalEmbeddingGenerator { - struct Args { - curandState* state; - float mean; - float std_dev; - }; - - DEVICE_INLINE - NormalEmbeddingGenerator(Args args): load_(false), state_(args.state), - mean(args.mean), std_dev(args.std_dev) {} - - DEVICE_INLINE - float generate(int64_t vec_id) { - if (!load_) { - localState_ = state_[GlobalThreadId()]; - load_ = true; - } - auto tmp = curand_normal_double(&this->localState_); - return static_cast(std_dev * tmp + mean); - } - - DEVICE_INLINE void destroy() { - if (load_) { - state_[GlobalThreadId()] = localState_; - } - } - - bool load_; - curandState localState_; - curandState* state_; - float mean; - float std_dev; -}; - -struct TruncatedNormalEmbeddingGenerator { - struct Args { - curandState* state; - float mean; - float std_dev; - float lower; - float upper; - }; - - DEVICE_INLINE - TruncatedNormalEmbeddingGenerator(Args args): load_(false), state_(args.state), - mean(args.mean), std_dev(args.std_dev), lower(args.lower), upper(args.upper) {} - - DEVICE_INLINE - float generate(int64_t vec_id) { - if (!load_) { - localState_ = state_[GlobalThreadId()]; - load_ = true; - } - auto l = normcdf((lower - mean) / std_dev); - auto u = normcdf((upper - mean) / std_dev); - u = 2 * u - 1; - l = 2 * l - 1; - float tmp = curand_uniform_double(&this->localState_); - tmp = tmp * (u - l) + l; - tmp = erfinv(tmp); - tmp *= scale * std_dev; - tmp += mean; - tmp = max(tmp, lower); - tmp = min(tmp, upper); - return tmp; - } - - DEVICE_INLINE void destroy() { - if (load_) { - state_[GlobalThreadId()] = localState_; - } - } - - bool load_; - curandState localState_; - curandState* state_; - float mean; - float std_dev; - float lower; - float upper; - double scale = sqrt(2.0f); -}; - -template -struct MappingEmbeddingGenerator { - struct Args { - const K* keys; - uint64_t mod; - }; - - DEVICE_INLINE - MappingEmbeddingGenerator(Args args): mod(args.mod), keys(args.keys) {} - - DEVICE_INLINE - float generate(int64_t vec_id) { - K key = keys[vec_id]; - return static_cast(key % mod); - } - - DEVICE_INLINE void destroy() {} - - uint64_t mod; - const K* keys; -}; - -struct ConstEmbeddingGenerator { - struct Args { - float val; - }; - - DEVICE_INLINE - ConstEmbeddingGenerator(Args args): val(args.val) {} - - DEVICE_INLINE - float generate(int64_t vec_id) { - return val; - } - - DEVICE_INLINE void destroy() {} - - float val; -}; - template struct OptStateInitializer { SizeType dim; @@ -575,6 +421,58 @@ void HKVVariable::accum_or_assign( DEMB_CUDA_KERNEL_LAUNCH_CHECK(); } +template +void HKVVariable::find_and_initialize( + const size_t n, const void *keys, void **value_ptrs, void *values, + bool *d_found, std::optional initializer_args_, const cudaStream_t& stream) { + if (n == 0) + return; + int dim = dim_; + this->find_pointers(n, keys, value_ptrs, d_found, nullptr, stream); + auto &device_prop = DeviceProp::getDeviceProp(); + int block_size = dim < device_prop.max_thread_per_block + ? dim + : device_prop.max_thread_per_block; + int grid_size = device_prop.num_sms * (device_prop.max_thread_per_sm / block_size); + + auto &init_args = initializer_args_.has_value() ? initializer_args_.value() : initializer_args; + auto &initializer_ = init_args.mode; + if (initializer_ == "normal") { + using Generator = NormalEmbeddingGenerator; + auto generator_args = typename Generator::Args {curand_states_, init_args.mean, init_args.std_dev}; + load_or_initialize_embeddings_kernel + <<>>( + n, dim, reinterpret_cast(values), reinterpret_cast(value_ptrs), d_found, generator_args); + } else if (initializer_ == "truncated_normal") { + using Generator = TruncatedNormalEmbeddingGenerator; + auto generator_args = typename Generator::Args {curand_states_, init_args.mean, init_args.std_dev, init_args.lower, init_args.upper}; + load_or_initialize_embeddings_kernel + <<>>( + n, dim, reinterpret_cast(values), reinterpret_cast(value_ptrs), d_found, generator_args); + } else if (initializer_ == "uniform") { + using Generator = UniformEmbeddingGenerator; + auto generator_args = typename Generator::Args {curand_states_, init_args.lower, init_args.upper}; + load_or_initialize_embeddings_kernel + <<>>( + n, dim, reinterpret_cast(values), reinterpret_cast(value_ptrs), d_found, generator_args); + } else if (initializer_ == "debug") { + using Generator = MappingEmbeddingGenerator; + auto generator_args = typename Generator::Args {reinterpret_cast(keys), 100000}; + load_or_initialize_embeddings_kernel + <<>>( + n, dim, reinterpret_cast(values), reinterpret_cast(value_ptrs), d_found, generator_args); + } else if (initializer_ == "constant") { + using Generator = ConstEmbeddingGenerator; + auto generator_args = typename Generator::Args {init_args.value}; + load_or_initialize_embeddings_kernel + <<>>( + n, dim, reinterpret_cast(values), reinterpret_cast(value_ptrs), d_found, generator_args); + } else { + throw std::runtime_error("Unrecognized initializer {" + initializer_ + "}"); + } + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + template void HKVVariable::find_or_insert( const size_t n, const void *keys, void **value_ptrs, void *values, @@ -792,6 +690,39 @@ void HKVVariable::export_batch_matched( DEMB_CUDA_KERNEL_LAUNCH_CHECK(); } +template +void HKVVariable::lock( + const size_t n, + const void* keys, // (n) + void** locked_keys_ptr, // (n) + bool* flags, // (n) + void* scores, + cudaStream_t stream +) { + hkv_table_->lock_keys( + n, + reinterpret_cast(keys), + reinterpret_cast(locked_keys_ptr), + flags, stream, (uint64_t*)scores); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void HKVVariable::unlock( + const size_t n, + void** locked_keys_ptr, // (n) + const void* keys, // (n) + bool* flags, // (n) + cudaStream_t stream +) { + hkv_table_->unlock_keys( + n, + reinterpret_cast(locked_keys_ptr), + reinterpret_cast(keys), + flags, stream); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + template curandState* HKVVariable::get_curand_states() const { return curand_states_; diff --git a/corelib/dynamicemb/src/hkv_variable.h b/corelib/dynamicemb/src/hkv_variable.h index 3e71eb2ef..02a5c4a64 100644 --- a/corelib/dynamicemb/src/hkv_variable.h +++ b/corelib/dynamicemb/src/hkv_variable.h @@ -84,6 +84,10 @@ class HKVVariable : public DynamicVariableBase { cudaStream_t stream = 0, bool unique_key = true, bool ignore_evict_strategy = false) override; + void find_and_initialize( + const size_t n, const void *keys, void **value_ptrs, void *values, + bool *founds, std::optional initializer_args, const cudaStream_t& stream) override; + void find_or_insert_pointers(const size_t n, const void *keys, // (n) void **value_ptrs, // (n * ptrs) bool *d_found, // (n * 1) @@ -140,6 +144,21 @@ class HKVVariable : public DynamicVariableBase { uint64_t threshold, uint64_t* d_counter, cudaStream_t stream = 0) const override; + + void lock( + const size_t n, + const void* keys, // (n) + void** locked_keys_ptr, // (n) + bool* flags = nullptr, // (n) + void* scores = nullptr, // (n) + cudaStream_t stream = 0) override; + + void unlock( + const size_t n, + void** locked_keys_ptr, // (n) + const void* keys, // (n) + bool* flags = nullptr, // (n) + cudaStream_t stream = 0) override; curandState* get_curand_states() const override; const InitializerArgs& get_initializer_args() const override; diff --git a/corelib/dynamicemb/src/index_calculation.cu b/corelib/dynamicemb/src/index_calculation.cu index 286ed8ba2..d76fbca04 100644 --- a/corelib/dynamicemb/src/index_calculation.cu +++ b/corelib/dynamicemb/src/index_calculation.cu @@ -17,9 +17,11 @@ #include "check.h" #include "index_calculation.h" +#include "utils.h" #include #include #include + namespace { // anonymous namespace template @@ -300,4 +302,277 @@ void SegmentedUniqueDevice::operator()( DEMB_CUDA_KERNEL_LAUNCH_CHECK(); } +template +__global__ void get_table_range_kernel(int64_t num_table, + int64_t feature_x_batch, + InT const *__restrict__ offsets, + OutT const *__restrict__ feature_offsets, + OutT *__restrict__ table_range) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid < num_table + 1) { + OutT num_feature = feature_offsets[num_table]; + int64_t batch = feature_x_batch / num_feature; + OutT feature_offset = feature_offsets[tid]; + int64_t feature_x_batch_offset = feature_offset * batch; + table_range[tid] = static_cast(offsets[feature_x_batch_offset]); + } +} + +at::Tensor get_table_range(at::Tensor offsets, at::Tensor feature_offsets) { + if (!offsets.is_cuda()) { + throw std::runtime_error("Tensor must be on CUDA device."); + } + if (!feature_offsets.is_cuda()) { + throw std::runtime_error( + "Tensor must be on CUDA device."); + } + int64_t feature_x_batch = offsets.size(0) - 1; + int64_t num_table = feature_offsets.size(0) - 1; + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + at::Tensor table_range = at::empty_like(feature_offsets); + + int block_size = 128; + if (num_table + 1 < block_size) { + block_size = num_table + 1; + } + int grid_size = (num_table + block_size) / block_size; + auto offset_type = scalartype_to_datatype(offsets.dtype().toScalarType()); + auto range_type = + scalartype_to_datatype(feature_offsets.dtype().toScalarType()); + DISPATCH_OFFSET_INT_TYPE(offset_type, offset_t, [&] { + DISPATCH_OFFSET_INT_TYPE(range_type, range_t, [&] { + get_table_range_kernel + <<>>( + num_table, feature_x_batch, + reinterpret_cast(offsets.data_ptr()), + reinterpret_cast(feature_offsets.data_ptr()), + reinterpret_cast(table_range.data_ptr())); + }); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + return table_range; +} + +std::tuple +segmented_unique( + at::Tensor keys, at::Tensor segment_range, + std::shared_ptr unique_op, + const c10::optional evict_strategy = c10::nullopt, + const c10::optional frequency_counts_uint64 = c10::nullopt) { + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + int64_t num_total = keys.size(0); + size_t unique_op_capacity = unique_op->get_capacity(); + if (num_total * 2 > unique_op_capacity) { + at::Tensor new_keys = at::empty({num_total * 2}, keys.options()); + at::Tensor new_vals = + at::empty({num_total * 2}, + at::TensorOptions().dtype(at::kLong).device(keys.device())); + unique_op->reset_capacity(new_keys, new_vals, num_total * 2, stream); + } + + at::Tensor h_segment_range = + at::empty(segment_range.sizes(), + segment_range.options().device(at::kCPU).pinned_memory(true)); + h_segment_range.copy_(segment_range, /*non_blocking=*/true); + + int table_num = segment_range.size(0) - 1; + + // Determine score strategy + bool is_lfu_enabled = false; + if (evict_strategy.has_value()) { + is_lfu_enabled = evict_strategy.value() == EvictStrategy::kLfu; + } + // Create vector of tensors for per-table frequency output + bool need_frequency_output = false; + need_frequency_output = is_lfu_enabled || frequency_counts_uint64.has_value(); + // Use single shared buffer instead of table_num separate buffers + at::Tensor shared_unique_buffer = at::empty(num_total, keys.options()); + at::Tensor shared_frequency_buffer; + if (need_frequency_output) { + shared_frequency_buffer = at::zeros( + num_total, at::TensorOptions().dtype(at::kLong).device(keys.device())); + } + at::Tensor d_unique_nums = at::empty(table_num, segment_range.options()); + at::Tensor d_unique_indices_table_range = + at::zeros(table_num + 1, segment_range.options()); + + auto unique_num_type = scalartype_to_datatype( + convertTypeMetaToScalarType(d_unique_nums.dtype())); + auto unique_offset_type = scalartype_to_datatype( + convertTypeMetaToScalarType(d_unique_indices_table_range.dtype())); + auto inverse_idx = at::empty(num_total, segment_range.options()); + + // sync for h_segment_range + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + + for (int i = 0; i < table_num; ++i) { + int64_t indices_begin = h_segment_range[i].item(); + int64_t indices_end = h_segment_range[i + 1].item(); + int64_t indices_length = indices_end - indices_begin; + + if (indices_length == 0) { + DEMB_CUDA_CHECK(cudaMemsetAsync( + reinterpret_cast(d_unique_nums.data_ptr()) + i, 0, + sizeof(int64_t), stream)); + dyn_emb::add_offset(d_unique_nums.data_ptr(), + d_unique_indices_table_range.data_ptr(), i, + unique_num_type, unique_offset_type, stream); + } else { + at::Tensor tmp_indices = keys.slice(0, indices_begin, indices_end); + at::Tensor tmp_inverse_idx = + inverse_idx.slice(0, indices_begin, indices_end); + at::Tensor tmp_d_unique_num = d_unique_nums.slice(0, i, table_num); + + at::Tensor previous_d_unique_num = + d_unique_indices_table_range.slice(0, i, table_num + 1); + + at::Tensor tmp_unique_buffer_slice = + shared_unique_buffer.slice(0, indices_begin, indices_end); + + at::Tensor tmp_frequency_counts_uint64; + at::Tensor tmp_frequency_output_slice; + if (frequency_counts_uint64.has_value()) { + // LFU mode: use input frequency counts + tmp_frequency_counts_uint64 = frequency_counts_uint64.value().slice( + 0, indices_begin, indices_end); + tmp_frequency_output_slice = + shared_frequency_buffer.slice(0, indices_begin, indices_end); + unique_op->unique(tmp_indices, indices_length, tmp_inverse_idx, + tmp_unique_buffer_slice, tmp_d_unique_num, stream, + previous_d_unique_num, tmp_frequency_output_slice, + tmp_frequency_counts_uint64); + } else if (need_frequency_output) { + tmp_frequency_output_slice = + shared_frequency_buffer.slice(0, indices_begin, indices_end); + unique_op->unique(tmp_indices, indices_length, tmp_inverse_idx, + tmp_unique_buffer_slice, tmp_d_unique_num, stream, + previous_d_unique_num, tmp_frequency_output_slice); + } else { + // Non-LFU mode: call unique without frequency counting + unique_op->unique(tmp_indices, indices_length, tmp_inverse_idx, + tmp_unique_buffer_slice, tmp_d_unique_num, stream, + previous_d_unique_num); + } + + dyn_emb::add_offset(d_unique_nums.data_ptr(), + d_unique_indices_table_range.data_ptr(), i, + unique_num_type, unique_offset_type, stream); + } + } + + at::Tensor h_unique_indices_table_range = + at::empty(table_num + 1, segment_range.options().device(at::kCPU)); + AT_CUDA_CHECK(cudaMemcpyAsync(h_unique_indices_table_range.data_ptr(), + d_unique_indices_table_range.data_ptr(), + (d_unique_indices_table_range.size(0)) * + d_unique_indices_table_range.element_size(), + cudaMemcpyDeviceToHost, stream)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + + int64_t unique_embs_offset = 0; + int64_t num_unique_total = + h_unique_indices_table_range[table_num].item(); + at::Tensor unique_keys = at::empty(num_unique_total, keys.options()); + at::Tensor output_scores; + output_scores = at::Tensor(); + if (need_frequency_output) { + output_scores = at::empty(num_unique_total, keys.options()); + } + for (int i = 0; i < table_num; ++i) { + int64_t tmp_unique_num = + h_unique_indices_table_range[i + 1].item() - + h_unique_indices_table_range[i].item(); + if (tmp_unique_num != 0) { + int64_t indices_begin = h_segment_range[i].item(); + + void *dst_ptr = reinterpret_cast(unique_keys.data_ptr()) + + unique_embs_offset * unique_keys.element_size(); + void *src_ptr = reinterpret_cast(shared_unique_buffer.data_ptr()) + + indices_begin * shared_unique_buffer.element_size(); + size_t copy_size = tmp_unique_num * unique_keys.element_size(); + AT_CUDA_CHECK(cudaMemcpyAsync(dst_ptr, src_ptr, copy_size, + cudaMemcpyDeviceToDevice, stream)); + if (need_frequency_output) { + void *dst_ptr = reinterpret_cast(output_scores.data_ptr()) + + unique_embs_offset * output_scores.element_size(); + void *src_ptr = reinterpret_cast(shared_frequency_buffer.data_ptr()) + + indices_begin * shared_frequency_buffer.element_size(); + size_t copy_size = tmp_unique_num * output_scores.element_size(); + AT_CUDA_CHECK(cudaMemcpyAsync(dst_ptr, src_ptr, copy_size, + cudaMemcpyDeviceToDevice, stream)); + } + } + unique_embs_offset += tmp_unique_num; + } + return std::make_tuple(unique_keys, inverse_idx, d_unique_indices_table_range, + h_unique_indices_table_range, output_scores); +} + +void select(at::Tensor flags, at::Tensor inputs, at::Tensor outputs, + at::Tensor num_selected) { + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + int64_t num_total = inputs.size(0); + auto scalar_type = inputs.dtype().toScalarType(); + auto key_type = scalartype_to_datatype(scalar_type); + auto num_select_iter_type = + scalartype_to_datatype(num_selected.dtype().toScalarType()); + + DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] { + DISPATCH_INTEGER_DATATYPE_FUNCTION( + num_select_iter_type, NumSelectedIteratorT, [&] { + select_async( + num_total, flags.data_ptr(), + reinterpret_cast(inputs.data_ptr()), + reinterpret_cast(outputs.data_ptr()), + reinterpret_cast(num_selected.data_ptr()), + inputs.device(), stream); + }); + }); +} + +void select_index(at::Tensor flags, at::Tensor output_indices, + at::Tensor num_selected) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + int64_t num_total = output_indices.size(0); + auto scalar_type = output_indices.dtype().toScalarType(); + auto key_type = scalartype_to_datatype(scalar_type); + auto num_select_iter_type = + scalartype_to_datatype(num_selected.dtype().toScalarType()); + + DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] { + DISPATCH_INTEGER_DATATYPE_FUNCTION( + num_select_iter_type, NumSelectedIteratorT, [&] { + select_index_async( + num_total, flags.data_ptr(), + reinterpret_cast(output_indices.data_ptr()), + reinterpret_cast(num_selected.data_ptr()), + output_indices.device(), stream); + }); + }); +} + } // namespace dyn_emb + +void bind_index_calculation_op(py::module &m) { + m.def("get_table_range", &dyn_emb::get_table_range, + "Make offsets from scope into scope", + py::arg("offsets"), py::arg("feature_offsets")); + + m.def("segmented_unique", &dyn_emb::segmented_unique, + "Dose segmented unique operation on keys with segment_range, return " + "tuple", + py::arg("keys"), py::arg("segment_range"), py::arg("unique_op"), + py::arg("evict_strategy") = c10::nullopt, + py::arg("frequency_counts_uint64") = c10::nullopt); + + m.def("select", &dyn_emb::select, + "Select items in inputs which flags are true.", py::arg("flags"), + py::arg("inputs"), py::arg("outputs"), py::arg("num_selected")); + m.def("select_index", &dyn_emb::select_index, + "Select items' indices where flags are true.", py::arg("flags"), + py::arg("output_indices"), py::arg("num_selected")); +} diff --git a/corelib/dynamicemb/src/index_calculation.h b/corelib/dynamicemb/src/index_calculation.h index b14a86e55..2408452cf 100644 --- a/corelib/dynamicemb/src/index_calculation.h +++ b/corelib/dynamicemb/src/index_calculation.h @@ -17,6 +17,8 @@ #pragma once #include "torch_utils.h" +#include "unique_op.h" +#include "lookup_forward.h" #include "utils.h" #include #include @@ -82,4 +84,67 @@ struct SegmentedUniqueDevice { cudaStream_t &stream); }; -} // namespace dyn_emb \ No newline at end of file +template +void select_async( + int64_t num_items, + bool const * d_flags, + T const * d_input, + T* d_output, + NumSelectedIteratorT* d_num_select, + at::Device const& device, + cudaStream_t const& stream +) { + + void* d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + + // 1. get the size of temp storage. + cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + d_input, d_flags, d_output, + d_num_select, num_items, stream); + + // 2. allocate the temp storage. + d_temp_storage = at::empty({static_cast(temp_storage_bytes)}, + at::TensorOptions().dtype(torch::kChar).device(device)).data_ptr(); + + // 3. select + cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + d_input, d_flags, d_output, + d_num_select, num_items, stream); +} + +template +void select_index_async( + int64_t num_items, + bool const * d_flags, + T* d_output, + NumSelectedIteratorT* d_num_select, + at::Device const& device, + cudaStream_t const& stream +) { + void* d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + cub::CountingInputIterator counting_iter(0); + + // 1. get the size of temp storage. + cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + counting_iter, d_flags, d_output, + d_num_select, num_items, stream); + + // 2. allocate the temp storage. + d_temp_storage = at::empty({static_cast(temp_storage_bytes)}, + at::TensorOptions().dtype(torch::kChar).device(device)).data_ptr(); + + // 3. select + cub::DeviceSelect::Flagged( + d_temp_storage, temp_storage_bytes, + counting_iter, d_flags, d_output, + d_num_select, num_items, stream); +} + +} // namespace dyn_emb + +void bind_index_calculation_op(py::module &m); \ No newline at end of file diff --git a/corelib/dynamicemb/src/initializer.cu b/corelib/dynamicemb/src/initializer.cu new file mode 100644 index 000000000..ddcf5851e --- /dev/null +++ b/corelib/dynamicemb/src/initializer.cu @@ -0,0 +1,236 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "initializer.cuh" + +namespace py = pybind11; + +namespace dyn_emb { + +__global__ void init_curand_state_kernel( + unsigned long long seed, + curandState *states +) { + auto grid = cooperative_groups::this_grid(); + curand_init(seed, grid.thread_rank(), 0, &states[grid.thread_rank()]); +} + +class CurandStateContext { + +public: + CurandStateContext() { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto &deviceProp = DeviceProp::getDeviceProp(); + num_worker_ = deviceProp.total_threads; + CUDACHECK(cudaMallocAsync( + &states_, sizeof(curandState) * num_worker_, stream)); + std::random_device rd; + auto seed = rd(); + int block_size = deviceProp.max_thread_per_block; + int grid_size = num_worker_ / block_size; + init_curand_state_kernel<<>>(seed, states_); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } + + ~CurandStateContext() { + // not async to avoid stream destroy case. + CUDACHECK(cudaDeviceSynchronize()); + CUDACHECK(cudaFree(states_)); + } + + int64_t num_worker() { + return num_worker_; + } + + curandState* ptr() { return states_; } + +private: + curandState* states_; + int64_t num_worker_; +}; + +template +__global__ void initialize_with_index_addressor_kernel( + int64_t num, + int64_t dim, + int64_t stide, + ValueT * __restrict__ buffer, + IndexT const * __restrict__ indices, + typename GeneratorT::Args generator_args +) { + + GeneratorT gen(generator_args); + int64_t num_task = num * dim; + int64_t task_id = blockIdx.x * blockDim.x + threadIdx.x; + + for (; task_id < num_task; task_id += gridDim.x * blockDim.x) { + int64_t emb_id = task_id / dim; + int64_t index = indices[emb_id]; + ValueT * dst = buffer + index * stide; + auto tmp = gen.generate(index); + dst[task_id % dim] = TypeConvertFunc::convert(tmp); + } + gen.destroy(); +} + +template +void initialize_with_generator( + at::Tensor buffer, + at::Tensor indices, + typename GeneratorT::Args generator_args, + int num_worker = -1 +) { + int num_dims = buffer.dim(); + if (num_dims != 2) { + throw std::runtime_error("Initializer'input buffer's dim have to be 2."); + } + if (buffer.stride(1) != 1) { + throw std::runtime_error("Initializer'input buffer has to be contiguous at dim1."); + } + int64_t num_total = indices.size(0); + int64_t dim = buffer.size(1); + int64_t stride = buffer.stride(0); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto &deviceProp = DeviceProp::getDeviceProp(); + + int block_size = deviceProp.max_thread_per_block; + int num_need = num_total * dim; + if (num_worker == -1) { + num_worker = deviceProp.total_threads; + } + if (num_worker > num_need) { + num_worker = num_need; + } + int grid_size = (num_worker - 1) / block_size + 1; + + auto value_type = + scalartype_to_datatype(convertTypeMetaToScalarType(buffer.dtype())); + auto index_type = + scalartype_to_datatype(convertTypeMetaToScalarType(indices.dtype())); + DISPATCH_FLOAT_DATATYPE_FUNCTION(value_type, ValueType, [&] { + DISPATCH_INTEGER_DATATYPE_FUNCTION(index_type, IndexType, [&] { + initialize_with_index_addressor_kernel + <<>>( + num_total, dim, stride, reinterpret_cast(buffer.data_ptr()), + reinterpret_cast(indices.data_ptr()), generator_args); + }); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void normal_init( + at::Tensor buffer, + at::Tensor indices, + CurandStateContext& curand_state_context, + float mean, + float std_dev +) { + + using GeneratorT = NormalEmbeddingGenerator; + auto generator_args = typename GeneratorT::Args {curand_state_context.ptr(), mean, std_dev}; + int num_worker = curand_state_context.num_worker(); + initialize_with_generator(buffer, indices, generator_args, num_worker); +} + +void truncated_normal_init( + at::Tensor buffer, + at::Tensor indices, + CurandStateContext& curand_state_context, + float mean, + float std_dev, + float lower, + float upper +) { + using GeneratorT = TruncatedNormalEmbeddingGenerator; + auto generator_args = typename GeneratorT::Args {curand_state_context.ptr(), mean, std_dev, lower, upper}; + int num_worker = curand_state_context.num_worker(); + initialize_with_generator(buffer, indices, generator_args, num_worker); +} + +void uniform_init( + at::Tensor buffer, + at::Tensor indices, + CurandStateContext& curand_state_context, + float lower, + float upper +) { + using GeneratorT = UniformEmbeddingGenerator; + auto generator_args = typename GeneratorT::Args {curand_state_context.ptr(), lower, upper}; + int num_worker = curand_state_context.num_worker(); + initialize_with_generator(buffer, indices, generator_args, num_worker); +} + +void const_init( + at::Tensor buffer, + at::Tensor indices, + float value +) { + using GeneratorT = ConstEmbeddingGenerator; + auto generator_args = typename GeneratorT::Args {value}; + initialize_with_generator(buffer, indices, generator_args); +} + +void debug_init( + at::Tensor buffer, + at::Tensor indices, + at::Tensor keys +) { + auto key_type = + scalartype_to_datatype(convertTypeMetaToScalarType(keys.dtype())); + DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, KeyType, [&] { + using GeneratorT = MappingEmbeddingGenerator; + auto generator_args = typename GeneratorT::Args {reinterpret_cast(keys.data_ptr()), 100000}; + initialize_with_generator(buffer, indices, generator_args); + }); + +} + +} // namespace dyn_emb + +void bind_initializer_op(py::module &m) { + + py::class_(m, "CurandStateContext") + .def(py::init<>()) + .def("ptr", &dyn_emb::CurandStateContext::ptr, + py::return_value_policy::reference); + + m.def("normal_init", &dyn_emb::normal_init, + "Normal initializer", + py::arg("buffer"), py::arg("indices"), py::arg("curand_state_context"), py::arg("mean"), py::arg("std_dev")); + + m.def("truncated_normal_init", &dyn_emb::truncated_normal_init, + "Truncated normal initializer", + py::arg("buffer"), py::arg("indices"), py::arg("curand_state_context"), + py::arg("mean"), py::arg("std_dev"), py::arg("lower"), py::arg("upper")); + + m.def( + "uniform_init", &dyn_emb::uniform_init, + "Uniform initializer", + py::arg("buffer"), py::arg("indices"), py::arg("curand_state_context"), + py::arg("lower"), py::arg("upper")); + + m.def( + "const_init", &dyn_emb::const_init, + "Const initializer", + py::arg("buffer"), py::arg("indices"), py::arg("value")); + + m.def( + "debug_init", &dyn_emb::debug_init, + "Debug initializer", + py::arg("buffer"), py::arg("indices"), py::arg("keys")); +} \ No newline at end of file diff --git a/corelib/dynamicemb/src/initializer.cuh b/corelib/dynamicemb/src/initializer.cuh new file mode 100644 index 000000000..45d1efbf8 --- /dev/null +++ b/corelib/dynamicemb/src/initializer.cuh @@ -0,0 +1,190 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include "utils.h" +#include "check.h" +#include "lookup_kernel.cuh" +#include "torch_utils.h" + +namespace dyn_emb { + +struct UniformEmbeddingGenerator { + struct Args { + curandState* state; + float lower; + float upper; + }; + + DEVICE_INLINE UniformEmbeddingGenerator(Args args): load_(false), state_(args.state), + lower(args.lower), upper(args.upper) {} + + DEVICE_INLINE float generate(int64_t vec_id) { + if (!load_) { + localState_ = state_[GlobalThreadId()]; + load_ = true; + } + auto tmp = curand_uniform_double(&this->localState_); + return static_cast((upper - lower) * tmp + lower); + } + + DEVICE_INLINE void destroy() { + if (load_) { + state_[GlobalThreadId()] = localState_; + } + } + + bool load_; + curandState localState_; + curandState* state_; + float lower; + float upper; +}; + +struct NormalEmbeddingGenerator { + struct Args { + curandState* state; + float mean; + float std_dev; + }; + + DEVICE_INLINE + NormalEmbeddingGenerator(Args args): load_(false), state_(args.state), + mean(args.mean), std_dev(args.std_dev) {} + + DEVICE_INLINE + float generate(int64_t vec_id) { + if (!load_) { + localState_ = state_[GlobalThreadId()]; + load_ = true; + } + auto tmp = curand_normal_double(&this->localState_); + return static_cast(std_dev * tmp + mean); + } + + DEVICE_INLINE void destroy() { + if (load_) { + state_[GlobalThreadId()] = localState_; + } + } + + bool load_; + curandState localState_; + curandState* state_; + float mean; + float std_dev; +}; + +struct TruncatedNormalEmbeddingGenerator { + struct Args { + curandState* state; + float mean; + float std_dev; + float lower; + float upper; + }; + + DEVICE_INLINE + TruncatedNormalEmbeddingGenerator(Args args): load_(false), state_(args.state), + mean(args.mean), std_dev(args.std_dev), lower(args.lower), upper(args.upper) {} + + DEVICE_INLINE + float generate(int64_t vec_id) { + if (!load_) { + localState_ = state_[GlobalThreadId()]; + load_ = true; + } + auto l = normcdf((lower - mean) / std_dev); + auto u = normcdf((upper - mean) / std_dev); + u = 2 * u - 1; + l = 2 * l - 1; + float tmp = curand_uniform_double(&this->localState_); + tmp = tmp * (u - l) + l; + tmp = erfinv(tmp); + tmp *= scale * std_dev; + tmp += mean; + tmp = max(tmp, lower); + tmp = min(tmp, upper); + return tmp; + } + + DEVICE_INLINE void destroy() { + if (load_) { + state_[GlobalThreadId()] = localState_; + } + } + + bool load_; + curandState localState_; + curandState* state_; + float mean; + float std_dev; + float lower; + float upper; + double scale = sqrt(2.0f); +}; + +template +struct MappingEmbeddingGenerator { + struct Args { + const K* keys; + uint64_t mod; + }; + + DEVICE_INLINE + MappingEmbeddingGenerator(Args args): mod(args.mod), keys(args.keys) {} + + DEVICE_INLINE + float generate(int64_t vec_id) { + K key = keys[vec_id]; + return static_cast(key % mod); + } + + DEVICE_INLINE void destroy() {} + + uint64_t mod; + const K* keys; +}; + +struct ConstEmbeddingGenerator { + struct Args { + float val; + }; + + DEVICE_INLINE + ConstEmbeddingGenerator(Args args): val(args.val) {} + + DEVICE_INLINE + float generate(int64_t vec_id) { + return val; + } + + DEVICE_INLINE void destroy() {} + + float val; +}; + +} // namespace dyn_emb diff --git a/corelib/dynamicemb/src/lookup_backward.cu b/corelib/dynamicemb/src/lookup_backward.cu index 028b96d86..eaf6aa935 100644 --- a/corelib/dynamicemb/src/lookup_backward.cu +++ b/corelib/dynamicemb/src/lookup_backward.cu @@ -561,7 +561,7 @@ __global__ void wgrad_reduction_kernel(const Key_t *unique_indices, const Key_t *inverse_indices, const Key_t *biased_offset, const Value_t *grads, Value_t *unique_buffer, int dim, - int batch_size, int feature_num, int num_key) { + int batch_size, int feature_num, int num_key, int combiner) { const int warpsize = 32; int tid = threadIdx.x; @@ -570,19 +570,24 @@ wgrad_reduction_kernel(const Key_t *unique_indices, Key_t src_id = bs_upper_bound_sub_one( biased_offset, batch_size * feature_num + 1, (Key_t)i_ev); + Value_t pooling_factor = 1.0f; + if (combiner == 1) { + pooling_factor = Value_t(static_cast(biased_offset[src_id + 1] - biased_offset[src_id])); + } + const Value_t *src_ptr = grads + src_id * dim; Key_t dst_id = inverse_indices[i_ev]; Value_t *dst_ptr = unique_buffer + dst_id * dim; for (int i = tid % warpSize; i < dim; i += warpsize) { - Value_t value = atomicAdd(dst_ptr + i, src_ptr[i]); + Value_t value = atomicAdd(dst_ptr + i, src_ptr[i] / pooling_factor); } } } void backward(void *grads, void *unique_buffer, void *unique_indices, void *inverse_indices, void *biased_offset, const int dim, - const int batch_size, const int feature_num, const int num_key, + const int batch_size, const int feature_num, const int num_key, int combiner, DataType key_type, DataType value_type, cudaStream_t stream) { DISPATCH_INTEGER_DATATYPE_FUNCTION(key_type, key_t, [&] { DISPATCH_FLOAT_DATATYPE_FUNCTION(value_type, value_t, [&] { @@ -591,7 +596,7 @@ void backward(void *grads, void *unique_buffer, void *unique_indices, wgrad_reduction_kernel<<>>( (key_t *)unique_indices, (key_t *)inverse_indices, (key_t *)biased_offset, (value_t *)grads, (value_t *)unique_buffer, - dim, batch_size, feature_num, num_key); + dim, batch_size, feature_num, num_key, combiner); }); }); DEMB_CUDA_KERNEL_LAUNCH_CHECK(); diff --git a/corelib/dynamicemb/src/lookup_backward.h b/corelib/dynamicemb/src/lookup_backward.h index 7ad0db708..311185b5c 100644 --- a/corelib/dynamicemb/src/lookup_backward.h +++ b/corelib/dynamicemb/src/lookup_backward.h @@ -47,7 +47,7 @@ class LocalReduce { void backward(void *grads, void *unique_buffer, void *unique_indices, void *inverse_indices, void *biased_offset, const int dim, - const int batch_size, const int feature_num, const int num_key, + const int batch_size, const int feature_num, const int num_key, int combiner, DataType key_type, DataType value_type, cudaStream_t stream); void one_to_one_atomic(void *grads, void *unique_indices, void *reverse_indices, void *unique_grads, const int ev_size, diff --git a/corelib/dynamicemb/src/module_bind.cu b/corelib/dynamicemb/src/module_bind.cu index 66afe6cb5..e7a4a681b 100644 --- a/corelib/dynamicemb/src/module_bind.cu +++ b/corelib/dynamicemb/src/module_bind.cu @@ -24,6 +24,9 @@ void bind_unique_op(py::module& m); void bind_bucktiz_kernel_op(py::module& m); void bind_optimizer_kernel_op(py::module& m); void bind_utils(py::module& m); +void bind_index_calculation_op(py::module& m); +void bind_initializer_op(py::module &m); +void bind_table_operation(py::module &m); PYBIND11_MODULE(dynamicemb_extensions, m) { m.doc() = "DYNAMICEMB"; // Optional @@ -32,5 +35,8 @@ PYBIND11_MODULE(dynamicemb_extensions, m) { bind_unique_op(m); bind_bucktiz_kernel_op(m); bind_optimizer_kernel_op(m); + bind_index_calculation_op(m); + bind_initializer_op(m); bind_utils(m); + bind_table_operation(m); } diff --git a/corelib/dynamicemb/src/optimizer.cu b/corelib/dynamicemb/src/optimizer.cu index 46d002404..f4a62e1e1 100644 --- a/corelib/dynamicemb/src/optimizer.cu +++ b/corelib/dynamicemb/src/optimizer.cu @@ -37,8 +37,7 @@ constexpr int OPTIMIZER_BLOCKSIZE = 1024; void dynamic_emb_sgd_with_table( std::shared_ptr table, const uint64_t n, - const at::Tensor indices, const at::Tensor grads, const float lr, DataType weight_type, - const std::optional score) { + const at::Tensor indices, const at::Tensor grads, const float lr, DataType weight_type) { if (n == 0) return; TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); @@ -104,8 +103,7 @@ void dynamic_emb_adam_with_table( const uint64_t n, const at::Tensor indices, const at::Tensor grads, const float lr, const float beta1, const float beta2, const float eps, const float weight_decay, - const uint32_t iter_num, DataType weight_type, - const std::optional score) { + const uint32_t iter_num, DataType weight_type) { if (n == 0) return; TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); @@ -118,7 +116,7 @@ void dynamic_emb_adam_with_table( auto stream = at::cuda::getCurrentCUDAStream().stream(); find_pointers(ht, n, indices, vector_ptrs, founds); - + auto &device_prop = DeviceProp::getDeviceProp(grads.device().index()); int64_t dim = grads.size(1); @@ -176,7 +174,7 @@ void dynamic_emb_adagrad_with_table( const at::Tensor grads, const float lr, const float eps, - DataType weight_type,const std::optional score){ + DataType weight_type){ if (n == 0) return; TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); @@ -244,7 +242,7 @@ void dynamic_emb_rowwise_adagrad_with_table( const at::Tensor grads, const float lr, const float eps, - DataType weight_type,const std::optional score) { + DataType weight_type){ if (n == 0) return; TORCH_CHECK(indices.is_cuda(), "indices must be a CUDA tensor"); TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); @@ -305,6 +303,352 @@ void dynamic_emb_rowwise_adagrad_with_table( DEMB_CUDA_KERNEL_LAUNCH_CHECK(); } +void dynamic_emb_sgd_with_pointer(at::Tensor grads, at::Tensor val_pointers, DataType val_type, float const lr) { + int64_t ev_nums = grads.size(0); + int64_t dim = grads.size(1); + if (ev_nums == 0) return; + TORCH_CHECK(val_pointers.is_cuda(), "val_pointers must be a CUDA tensor"); + TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto &device_prop = DeviceProp::getDeviceProp(grads.device().index()); + + auto grad_type = + scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + + SgdVecOptimizer opt{lr}; + if (dim % 4 == 0) { + const int max_grid_size = + device_prop.num_sms * + (device_prop.max_thread_per_sm / OPTIMIZER_BLOCKSIZE_VEC); + const int warp_per_block = OPTIMIZER_BLOCKSIZE_VEC / WARPSIZE; + + int grid_size = 0; + if (ev_nums / warp_per_block < max_grid_size) { + grid_size = (ev_nums - 1) / warp_per_block + 1; + } else if (ev_nums / warp_per_block > max_grid_size * MULTIPLIER) { + grid_size = max_grid_size * MULTIPLIER; + } else { + grid_size = max_grid_size; + } + + auto kernel = update4_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + int block_size = dim > OPTIMIZER_BLOCKSIZE ? OPTIMIZER_BLOCKSIZE : dim; + int grid_size = ev_nums; + + auto kernel = update_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void dynamic_emb_adam_with_pointer( + at::Tensor grads, at::Tensor val_pointers, DataType val_type, int64_t state_dim, + const float lr, const float beta1, const float beta2, const float eps, + const float weight_decay, const uint32_t iter_num +) { + int64_t ev_nums = grads.size(0); + int64_t dim = grads.size(1); + if (ev_nums == 0) return; + TORCH_CHECK(val_pointers.is_cuda(), "val_pointers must be a CUDA tensor"); + TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto &device_prop = DeviceProp::getDeviceProp(grads.device().index()); + + auto grad_type = + scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + AdamVecOptimizer opt{lr, + beta1, + beta2, + eps, + weight_decay, + iter_num}; + if (dim % 4 == 0) { + const int max_grid_size = + device_prop.num_sms * + (device_prop.max_thread_per_sm / OPTIMIZER_BLOCKSIZE_VEC); + const int warp_per_block = OPTIMIZER_BLOCKSIZE_VEC / WARPSIZE; + + int grid_size = 0; + if (ev_nums / warp_per_block < max_grid_size) { + grid_size = (ev_nums - 1) / warp_per_block + 1; + } else if (ev_nums / warp_per_block > max_grid_size * MULTIPLIER) { + grid_size = max_grid_size * MULTIPLIER; + } else { + grid_size = max_grid_size; + } + + auto kernel = update4_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + int block_size = dim > OPTIMIZER_BLOCKSIZE ? OPTIMIZER_BLOCKSIZE : dim; + int grid_size = ev_nums; + + auto kernel = update_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + +} + +void dynamic_emb_adagrad_with_pointer( + at::Tensor grads, at::Tensor val_pointers, DataType val_type, int64_t state_dim, + const float lr, const float eps) { + + int64_t ev_nums = grads.size(0); + int64_t dim = grads.size(1); + if (ev_nums == 0) return; + + TORCH_CHECK(val_pointers.is_cuda(), "val_pointers must be a CUDA tensor"); + TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto& device_prop = DeviceProp::getDeviceProp(grads.device().index()); + + auto grad_type = scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + + AdaGradVecOptimizer opt{lr, eps}; + + if (dim % 4 == 0) { + const int max_grid_size = device_prop.num_sms * (device_prop.max_thread_per_sm / OPTIMIZER_BLOCKSIZE_VEC); + const int warp_per_block = OPTIMIZER_BLOCKSIZE_VEC/WARPSIZE; + + int grid_size = 0; + if (ev_nums/warp_per_block < max_grid_size){ + grid_size = (ev_nums-1)/warp_per_block+1; + } + else if (ev_nums/warp_per_block > max_grid_size*MULTIPLIER){ + grid_size = max_grid_size*MULTIPLIER; + } + else{ + grid_size = max_grid_size; + } + + auto kernel = update4_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + + int block_size = dim > OPTIMIZER_BLOCKSIZE ? OPTIMIZER_BLOCKSIZE : dim; + int grid_size = ev_nums; + + auto kernel = update_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + +} + +void dynamic_emb_rowwise_adagrad_with_pointer( + at::Tensor grads, at::Tensor val_pointers, DataType val_type, int64_t state_dim, + const float lr, const float eps) { + + int64_t ev_nums = grads.size(0); + int64_t dim = grads.size(1); + if (ev_nums == 0) return; + + TORCH_CHECK(val_pointers.is_cuda(), "val_pointers must be a CUDA tensor"); + TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto& device_prop = DeviceProp::getDeviceProp(grads.device().index()); + + auto grad_type = scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + + RowWiseAdaGradVecOptimizer opt {lr, eps}; + if (dim % 4 == 0) { + const int max_grid_size = device_prop.num_sms * (device_prop.max_thread_per_sm / OPTIMIZER_BLOCKSIZE_VEC); + const int warp_per_block = OPTIMIZER_BLOCKSIZE_VEC / WARPSIZE; + + int grid_size = 0; + if (ev_nums / warp_per_block < max_grid_size) { + grid_size = (ev_nums-1) / warp_per_block + 1; + } + else if (ev_nums / warp_per_block > max_grid_size * MULTIPLIER) { + grid_size = max_grid_size * MULTIPLIER; + } else { + grid_size = max_grid_size; + } + + auto kernel = update4_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + + } else { + + int block_size = dim > OPTIMIZER_BLOCKSIZE ? OPTIMIZER_BLOCKSIZE : dim; + int grid_size = ev_nums; + int shared_memory_bytes = block_size * sizeof(float); + + auto kernel = update_kernel; + kernel<<>>( + ev_nums, dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(val_pointers.data_ptr()), nullptr, opt); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); + } + }); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template < +typename opt_t, +typename g_t, +typename w_t +> +void fused_update(const opt_t& opt, at::Tensor grads, at::Tensor values) { + + int64_t ev_nums = grads.size(0); + int64_t dim = grads.size(1); + int64_t val_dim = values.size(1); + if (ev_nums == 0) return; + TORCH_CHECK(values.is_cuda(), "values must be a CUDA tensor"); + TORCH_CHECK(grads.is_cuda(), "grads must be a CUDA tensor"); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + auto &device_prop = DeviceProp::getDeviceProp(grads.device().index()); + + if (dim % 4 == 0) { + const int max_grid_size = + device_prop.num_sms * + (device_prop.max_thread_per_sm / OPTIMIZER_BLOCKSIZE_VEC); + const int warp_per_block = OPTIMIZER_BLOCKSIZE_VEC / WARPSIZE; + + int grid_size = 0; + if (ev_nums / warp_per_block < max_grid_size) { + grid_size = (ev_nums - 1) / warp_per_block + 1; + } else if (ev_nums / warp_per_block > max_grid_size * MULTIPLIER) { + grid_size = max_grid_size * MULTIPLIER; + } else { + grid_size = max_grid_size; + } + + auto kernel = update4_kernel_fused; + kernel<<>>( + ev_nums, dim, val_dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(values.data_ptr()), nullptr, opt); + } else { + int block_size = dim > OPTIMIZER_BLOCKSIZE ? OPTIMIZER_BLOCKSIZE : dim; + int grid_size = ev_nums; + + auto kernel = update_kernel_fused; + kernel<<>>( + ev_nums, dim, val_dim, reinterpret_cast(grads.data_ptr()), + reinterpret_cast(values.data_ptr()), nullptr, opt); + } + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void dynamic_emb_sgd_fused(at::Tensor grads, at::Tensor values, float const lr) { + + auto grad_type = + scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + auto val_type = + scalartype_to_datatype(convertTypeMetaToScalarType(values.dtype())); + + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + SgdVecOptimizer opt{lr}; + fused_update(opt, grads, values); + }); + }); +} + +void dynamic_emb_adam_fused( + at::Tensor grads, at::Tensor values, + const float lr, const float beta1, const float beta2, const float eps, + const float weight_decay, const uint32_t iter_num +) { + auto grad_type = + scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + auto val_type = + scalartype_to_datatype(convertTypeMetaToScalarType(values.dtype())); + + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + AdamVecOptimizer opt{lr, + beta1, + beta2, + eps, + weight_decay, + iter_num}; + fused_update(opt, grads, values); + }); + }); +} + +void dynamic_emb_adagrad_fused( + at::Tensor grads, at::Tensor values, + const float lr, const float eps) { + + auto grad_type = + scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + auto val_type = + scalartype_to_datatype(convertTypeMetaToScalarType(values.dtype())); + + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + AdaGradVecOptimizer opt{lr, eps}; + fused_update(opt, grads, values); + }); + }); +} + +void dynamic_emb_rowwise_adagrad_fused( + at::Tensor grads, at::Tensor values, + const float lr, const float eps) { + + auto grad_type = + scalartype_to_datatype(convertTypeMetaToScalarType(grads.dtype())); + auto val_type = + scalartype_to_datatype(convertTypeMetaToScalarType(values.dtype())); + + DISPATCH_FLOAT_DATATYPE_FUNCTION(grad_type, g_t, [&] { + DISPATCH_FLOAT_DATATYPE_FUNCTION(val_type, w_t, [&] { + RowWiseAdaGradVecOptimizer opt {lr, eps}; + fused_update(opt, grads, values); + }); + }); +} + } // namespace dyn_emb // PYTHON WRAP @@ -312,24 +656,24 @@ void bind_optimizer_kernel_op(py::module &m) { m.def("dynamic_emb_sgd_with_table", &dyn_emb::dynamic_emb_sgd_with_table, "SGD optimizer for Dynamic Emb", py::arg("table"), py::arg("n"), py::arg("indices"), py::arg("grads"), - py::arg("lr"), py::arg("weight_type"), py::arg("score") = py::none()); + py::arg("lr"), py::arg("weight_type")); m.def("dynamic_emb_adam_with_table", &dyn_emb::dynamic_emb_adam_with_table, "Adam optimizer for Dynamic Emb", py::arg("ht"), py::arg("n"), py::arg("indices"), py::arg("grads"), py::arg("lr"), py::arg("beta1"), py::arg("beta2"), py::arg("eps"), py::arg("weight_decay"), py::arg("iter_num"), - py::arg("weight_type"), py::arg("score") = py::none()); + py::arg("weight_type")); m.def("dynamic_emb_adagrad_with_table", &dyn_emb::dynamic_emb_adagrad_with_table, "Adagrad optimizer for Dynamic Emb", py::arg("ht"), py::arg("n"), py::arg("indices"), py::arg("grads"),py::arg("lr"), py::arg("eps"), - py::arg("weight_type"), py::arg("score") = py::none()); + py::arg("weight_type")); m.def("dynamic_emb_rowwise_adagrad_with_table", &dyn_emb::dynamic_emb_rowwise_adagrad_with_table, "Row Wise Adagrad optimizer for Dynamic Emb", py::arg("ht"), py::arg("n"), py::arg("indices"), py::arg("grads"),py::arg("lr"), py::arg("eps"), - py::arg("weight_type"), py::arg("score") = py::none()); + py::arg("weight_type")); } diff --git a/corelib/dynamicemb/src/optimizer.h b/corelib/dynamicemb/src/optimizer.h index 80cffe55b..5d8a74ec5 100644 --- a/corelib/dynamicemb/src/optimizer.h +++ b/corelib/dynamicemb/src/optimizer.h @@ -38,15 +38,13 @@ namespace dyn_emb { void dynamic_emb_sgd_with_table(std::shared_ptr table, const uint64_t n, const at::Tensor indices, const at::Tensor grads, - const float lr, DataType weight_type, const std::optional score = std::nullopt); + const float lr, DataType weight_type); void dynamic_emb_adam_with_table( std::shared_ptr ht, const uint64_t n, const at::Tensor indices, const at::Tensor grads, const float lr, const float beta1, const float beta2, const float eps, - const float weight_decay, const uint32_t iter_num, DataType weight_type, - const std::optional score = std::nullopt -); + const float weight_decay, const uint32_t iter_num, DataType weight_type); void dynamic_emb_adagrad_with_table( std::shared_ptr ht, @@ -54,7 +52,7 @@ void dynamic_emb_adagrad_with_table( const at::Tensor grads, const float lr, const float eps, - DataType weight_type,const std::optional score = std::nullopt); + DataType weight_type); void dynamic_emb_rowwise_adagrad_with_table( std::shared_ptr ht, @@ -62,7 +60,7 @@ void dynamic_emb_rowwise_adagrad_with_table( const at::Tensor grads, const float lr, const float eps, - DataType weight_type,const std::optional score = std::nullopt); + DataType weight_type); } // namespace dyn_emb #endif // OPTIMIZER_H diff --git a/corelib/dynamicemb/src/optimizer_kernel.cuh b/corelib/dynamicemb/src/optimizer_kernel.cuh index 7ba18fc16..f7c243072 100644 --- a/corelib/dynamicemb/src/optimizer_kernel.cuh +++ b/corelib/dynamicemb/src/optimizer_kernel.cuh @@ -409,10 +409,10 @@ __global__ void update4_kernel(const uint32_t num_keys, const uint32_t dim, cons for (uint32_t ev_id = warp_num_per_block * blockIdx.x + warp_id_in_block; ev_id < num_keys; ev_id += gridDim.x * warp_num_per_block) { - bool mask = masks[ev_id]; + bool mask = masks ? masks[ev_id] : true; weight_t *weight_ptr = weight_evs[ev_id]; const wgrad_t *grad_ptr = grad_evs + ev_id * dim; - if (!mask) { + if ((!mask) or (weight_ptr == nullptr)) { continue; } OptimizierInput input {grad_ptr, weight_ptr, dim}; @@ -426,10 +426,47 @@ __global__ void update_kernel(const uint32_t num_keys, const uint32_t dim, const constexpr int kWarpSize = 32; for (uint32_t ev_id = blockIdx.x; ev_id < num_keys; ev_id += gridDim.x) { - bool mask = masks[ev_id]; + bool mask = masks ? masks[ev_id] : true; weight_t *weight_ptr = weight_evs[ev_id]; const wgrad_t *grad_ptr = grad_evs + ev_id * dim; - if (!mask) { + if ((!mask) or (weight_ptr == nullptr)) { + continue; + } + OptimizierInput input {grad_ptr, weight_ptr, dim}; + optimizer.update(input); + } +} + +template +__global__ void update4_kernel_fused(const uint32_t num_keys, const uint32_t dim, const uint32_t val_dim, const wgrad_t *grad_evs, + weight_t *weight_evs, const bool* masks, OptimizerFunc optimizer) { + constexpr int kWarpSize = 32; + const int warp_num_per_block = blockDim.x / kWarpSize; + const int warp_id_in_block = threadIdx.x / kWarpSize; + + for (uint32_t ev_id = warp_num_per_block * blockIdx.x + warp_id_in_block; + ev_id < num_keys; ev_id += gridDim.x * warp_num_per_block) { + bool mask = masks ? masks[ev_id] : true; + weight_t *weight_ptr = weight_evs + ev_id * val_dim; + const wgrad_t *grad_ptr = grad_evs + ev_id * dim; + if ((!mask) or (weight_ptr == nullptr)) { + continue; + } + OptimizierInput input {grad_ptr, weight_ptr, dim}; + optimizer.update4(input); + } +} + +template +__global__ void update_kernel_fused(const uint32_t num_keys, const uint32_t dim, const uint32_t val_dim, const wgrad_t *grad_evs, + weight_t *weight_evs, const bool* masks, OptimizerFunc optimizer) { + constexpr int kWarpSize = 32; + + for (uint32_t ev_id = blockIdx.x; ev_id < num_keys; ev_id += gridDim.x) { + bool mask = masks ? masks[ev_id] : true; + weight_t *weight_ptr = weight_evs + ev_id * val_dim; + const wgrad_t *grad_ptr = grad_evs + ev_id * dim; + if ((!mask) or (weight_ptr == nullptr)) { continue; } OptimizierInput input {grad_ptr, weight_ptr, dim}; diff --git a/corelib/dynamicemb/src/table_operation/erase.cu b/corelib/dynamicemb/src/table_operation/erase.cu new file mode 100644 index 000000000..9c8db8fe8 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/erase.cu @@ -0,0 +1,61 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "kernels.cuh" +#include "table.cuh" + +namespace dyn_emb { + +void table_erase(at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, + at::Tensor keys, std::optional indices) { + + int64_t num_total = keys.size(0); + if (num_total == 0) + return; + + auto key_type = get_data_type(keys); + auto bucket_sizes_ = get_pointer(bucket_sizes); + auto indices_ = get_pointer(indices); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + constexpr int BLOCK_SIZE = 256; + + DISPATCH_KEY_TYPE(key_type, KeyType, [&] { + auto keys_ = get_pointer(keys); + + constexpr int64_t total_size = + sizeof(KeyType) + sizeof(DigestType) + sizeof(ScoreType); + int64_t bucket_bytes = bucket_capacity * total_size; + int64_t num_buckets = + table_storage.numel() * table_storage.element_size() / bucket_bytes; + + using Bucket = LinearBucket; + using Table = LinearBucketTable; + + auto table = Table(reinterpret_cast(table_storage.data_ptr()), + num_buckets, bucket_capacity); + + table_erase_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + table, bucket_sizes_, num_total, keys_, indices_); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/table_operation/export_batch.cu b/corelib/dynamicemb/src/table_operation/export_batch.cu new file mode 100644 index 000000000..ebbba10d8 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/export_batch.cu @@ -0,0 +1,91 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "kernels.cuh" +#include "table.cuh" + +namespace dyn_emb { + +void table_export_single_score(at::Tensor table_storage, + std::vector dtypes, + int64_t bucket_capacity, int64_t batch, + int64_t offset, at::Tensor counter, + at::Tensor keys, + std::vector> scores, + std::optional indices) { + auto key_type = get_data_type(keys); + auto scores_ = get_pointer(scores[0]); + auto indices_ = get_pointer(indices); + auto counter_ = get_pointer(counter); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + int64_t num_total = batch; + + constexpr int BLOCK_SIZE = 256; + + DISPATCH_KEY_TYPE(key_type, KeyType, [&] { + auto keys_ = get_pointer(keys); + + constexpr int64_t total_size = + sizeof(KeyType) + sizeof(DigestType) + sizeof(ScoreType); + int64_t bucket_bytes = bucket_capacity * total_size; + int64_t num_buckets = + table_storage.numel() * table_storage.element_size() / bucket_bytes; + + using Bucket = LinearBucket; + using Table = LinearBucketTable; + + auto table = Table(reinterpret_cast(table_storage.data_ptr()), + num_buckets, bucket_capacity); + + if (offset + num_total > num_buckets * bucket_capacity) { + throw std::invalid_argument("Offset and batch size overflow."); + } + + if (num_total % 32 == 0) { + table_export_batch_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, + stream>>>(table, offset, offset + num_total, counter_, keys_, + scores_, indices_); + } else { + table_export_batch_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, + stream>>>(table, offset, offset + num_total, counter_, keys_, + scores_, indices_); + } + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void table_export_batch(at::Tensor table_storage, + std::vector dtypes, + int64_t bucket_capacity, int64_t batch, int64_t offset, + at::Tensor counter, at::Tensor keys, + std::vector> scores, + std::optional indices) { + if (batch == 0) + return; + + if (scores.size() == 1) { + table_export_single_score(table_storage, dtypes, bucket_capacity, batch, + offset, counter, keys, scores, indices); + } else { + throw std::runtime_error("Not support multi-scores."); + } +} +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/table_operation/insert.cu b/corelib/dynamicemb/src/table_operation/insert.cu new file mode 100644 index 000000000..be6f47553 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/insert.cu @@ -0,0 +1,102 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "kernels.cuh" +#include "table.cuh" + +namespace dyn_emb { + +void table_insert_single_score(at::Tensor table_storage, + std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, + at::Tensor keys, + std::vector> scores, + std::vector policy_types, + std::vector is_returns, + std::optional indices, + std::optional insert_results) { + + auto key_type = get_data_type(keys); + + bool is_return = is_returns[0]; + ScorePolicyType policy_type = policy_types[0]; + auto scores_ = get_pointer(scores[0]); + auto indices_ = get_pointer(indices); + auto insert_results_ = get_pointer(insert_results); + auto bucket_sizes_ = get_pointer(bucket_sizes); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + int64_t num_total = keys.size(0); + + auto table_key_slots = at::zeros( + num_total, at::TensorOptions().dtype(at::kLong).device(keys.device())); + + constexpr int BLOCK_SIZE = 256; + + DISPATCH_KEY_TYPE(key_type, KeyType, [&] { + auto keys_ = get_pointer(keys); + auto table_key_slots_ = get_pointer(table_key_slots); + + constexpr int64_t total_size = + sizeof(KeyType) + sizeof(DigestType) + sizeof(ScoreType); + int64_t bucket_bytes = bucket_capacity * total_size; + int64_t num_buckets = + table_storage.numel() * table_storage.element_size() / bucket_bytes; + + using Bucket = LinearBucket; + using Table = LinearBucketTable; + + auto table = Table(reinterpret_cast(table_storage.data_ptr()), + num_buckets, bucket_capacity); + + using KernelTraits = InsertKernelTraits; + + table_insert_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + table, bucket_sizes_, num_total, keys_, insert_results_, indices_, + scores_, policy_type, is_return, table_key_slots_); + + table_unlock_kernel
+ <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + table, num_total, keys_, table_key_slots_); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void table_insert(at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, + at::Tensor keys, + std::vector> scores, + std::vector policy_types, + std::vector is_returns, + std::optional indices, + std::optional insert_results) { + + int64_t num_total = keys.size(0); + if (num_total == 0) + return; + if (scores.size() == 1) { + table_insert_single_score(table_storage, dtypes, bucket_capacity, + bucket_sizes, keys, scores, policy_types, + is_returns, indices, insert_results); + } else { + throw std::runtime_error("Not support multi-scores."); + } +} + +} // namespace dyn_emb diff --git a/corelib/dynamicemb/src/table_operation/insert_and_evict.cu b/corelib/dynamicemb/src/table_operation/insert_and_evict.cu new file mode 100644 index 000000000..bda7b21aa --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/insert_and_evict.cu @@ -0,0 +1,119 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "kernels.cuh" +#include "table.cuh" + +namespace dyn_emb { + +void table_insert_and_evict_single_score( + at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, at::Tensor keys, + std::vector> scores, + std::vector policy_types, std::vector is_returns, + std::optional insert_results, std::optional indices, + at::Tensor num_evicted, at::Tensor evicted_keys, at::Tensor evicted_indices, + std::vector evicted_scores) { + + auto key_type = get_data_type(keys); + + bool is_return = is_returns[0]; + ScorePolicyType policy_type = policy_types[0]; + auto scores_ = get_pointer(scores[0]); + auto indices_ = get_pointer(indices); + auto insert_results_ = get_pointer(insert_results); + auto bucket_sizes_ = get_pointer(bucket_sizes); + + auto evict_counter_ = get_pointer(num_evicted); + auto evicted_scores_ = get_pointer(evicted_scores[0]); + auto evicted_indices_ = get_pointer(evicted_indices); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + int64_t num_total = keys.size(0); + + auto table_key_slots = at::zeros( + num_total, at::TensorOptions().dtype(at::kLong).device(keys.device())); + + constexpr int BLOCK_SIZE = 256; + + DISPATCH_KEY_TYPE(key_type, KeyType, [&] { + auto keys_ = get_pointer(keys); + auto evicted_keys_ = get_pointer(evicted_keys); + auto table_key_slots_ = get_pointer(table_key_slots); + + constexpr int64_t total_size = + sizeof(KeyType) + sizeof(DigestType) + sizeof(ScoreType); + int64_t bucket_bytes = bucket_capacity * total_size; + int64_t num_buckets = + table_storage.numel() * table_storage.element_size() / bucket_bytes; + + using Bucket = LinearBucket; + using Table = LinearBucketTable; + + auto table = Table(reinterpret_cast(table_storage.data_ptr()), + num_buckets, bucket_capacity); + + if (num_total % 32 == 0) { + using KernelTraits = InsertKernelTraits; + table_insert_and_evict_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, + stream>>>(table, bucket_sizes_, num_total, keys_, insert_results_, + indices_, scores_, policy_type, is_return, + table_key_slots_, evict_counter_, evicted_keys_, + evicted_scores_, evicted_indices_); + } else { + using KernelTraits = InsertKernelTraits; + table_insert_and_evict_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, + stream>>>(table, bucket_sizes_, num_total, keys_, insert_results_, + indices_, scores_, policy_type, is_return, + table_key_slots_, evict_counter_, evicted_keys_, + evicted_scores_, evicted_indices_); + } + + table_unlock_kernel
+ <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + table, num_total, keys_, table_key_slots_); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void table_insert_and_evict( + at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, at::Tensor keys, + std::vector> scores, + std::vector policy_types, std::vector is_returns, + std::optional insert_results, std::optional indices, + at::Tensor num_evicted, at::Tensor evicted_keys, at::Tensor evicted_indices, + std::vector evicted_scores) { + + int64_t num_total = keys.size(0); + if (num_total == 0) + return; + + if (scores.size() == 1) { + table_insert_and_evict_single_score( + table_storage, dtypes, bucket_capacity, bucket_sizes, keys, scores, + policy_types, is_returns, insert_results, indices, num_evicted, + evicted_keys, evicted_indices, evicted_scores); + } else { + throw std::runtime_error("Not support multi-scores."); + } +} + +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/table_operation/kernels.cuh b/corelib/dynamicemb/src/table_operation/kernels.cuh new file mode 100644 index 000000000..93455a261 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/kernels.cuh @@ -0,0 +1,509 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#pragma once + +#include "types.cuh" +#include + +#include + +using namespace cooperative_groups; +namespace cg = cooperative_groups; + +namespace dyn_emb { + +template +struct InsertKernelTraits { + static constexpr int ThreadBlockDim = ThreadBlockDim_; + static constexpr int ProbingGroupSize = ProbingGroupSize_; + static constexpr int ReductionGroupSize = ReductionGroupSize_; + static constexpr int CompactTileSize = CompactTileSize_; + static constexpr int NumScorePerThread = NumScorePerThread_; +}; + +template +__global__ void +table_lookup_kernel(Table table, int64_t batch, + typename Table::KeyType const *__restrict__ input_keys, + bool *__restrict__ founds, IndexType *__restrict__ indices, + ScoreType *__restrict__ scores, ScorePolicyType policy_type, + bool return_scores) { + + using KeyType = typename Table::KeyType; + using Bucket = typename Table::BucketType; + using Iter = typename Bucket::Iterator; + + auto tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + + for (int64_t i = tid; i < batch; i += gridDim.x * blockDim.x) { + + KeyType key = input_keys[i]; + ScoreType score = ScorePolicy::get(policy_type, scores, i); + + Bucket bucket; + KeyType hashcode = KeyType(); + int64_t bucket_id; + if (Bucket::is_valid(key)) { + hashcode = Table::hash(key); + uint64_t global_idx = static_cast(hashcode % table.capacity()); + bucket_id = global_idx / table.bucket_capacity(); + // bucket_id = (hashcode % table.capacity()) / table.bucket_capacity(); + bucket = table[bucket_id]; + } + Iter iter = Iter(hashcode % table.bucket_capacity()); + int step = 0; + auto probe_res = bucket.probe(key, iter, step); + bool found = probe_res == Bucket::ProbeResult::Existed; + IndexType index = -1; + if (found) { + + if (policy_type == ScorePolicyType::Const) { + score = *bucket.scores(iter); + } else { + KeyType expected_key = key; + if (bucket.try_lock(iter, expected_key)) { + ScorePolicy::update(policy_type, return_scores, bucket.scores(iter), + score); + bucket.unlock(iter, key); + } else { + found = false; // only one update will succeed for duplicated keys. + score = ScoreType(); + } + } + + if (found) { + index = bucket_id * bucket.capacity() + iter; + } + } + ScorePolicy::set(return_scores, scores, i, score); + if (founds) { + founds[i] = found; + } + if (indices) { + indices[i] = index; + } + } +} + +template +__global__ void +table_insert_kernel(Table table, int *__restrict__ bucket_sizes, int64_t batch, + typename Table::KeyType const *__restrict__ input_keys, + InsertResult *__restrict__ insert_results, + IndexType *__restrict__ indices, + ScoreType *__restrict__ scores, ScorePolicyType policy_type, + bool return_scores, + typename Table::KeyType **__restrict__ table_key_slots) { + + using KeyType = typename Table::KeyType; + using Bucket = typename Table::BucketType; + using Iter = typename Bucket::Iterator; + using ProbeResult = typename Bucket::ProbeResult; + + static constexpr int BlockSize = KernelTraits::ThreadBlockDim; + static constexpr int BufferDim = KernelTraits::NumScorePerThread; + + static constexpr int ProbingGroupSize = KernelTraits::ProbingGroupSize; + static constexpr int ReductionGroupSize = KernelTraits::ReductionGroupSize; + + auto tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + + __shared__ ScoreType sm_scores[BlockSize * BufferDim]; + // extern __shared__ ScoreType sm_scores[]; + // cuda::pipeline pipe = cuda::make_pipeline(); + + for (int64_t i = tid; i < batch; i += gridDim.x * blockDim.x) { + + KeyType key = input_keys[i]; + ScoreType score = ScorePolicy::get(policy_type, scores, i); + + InsertResult result = InsertResult::Init; + + Bucket bucket; + KeyType hashcode = KeyType(); + uint64_t bucket_id; + if (Bucket::is_valid(key)) { + hashcode = Table::hash(key); + uint64_t global_idx = static_cast(hashcode % table.capacity()); + bucket_id = global_idx / table.bucket_capacity(); + // bucket_id = (hashcode % table.capacity()) / table.bucket_capacity(); + bucket = table[bucket_id]; + } + Iter iter = Iter(hashcode % table.bucket_capacity()); + ProbeResult probe_res = ProbeResult::Init; + int step = 0; + while (step != bucket.capacity()) { + probe_res = bucket.probe(key, iter, step); + if (probe_res == ProbeResult::Existed) { + KeyType expected_key = key; + + if (bucket.try_lock(iter, expected_key)) { + result = InsertResult::Assign; + // bucket.unlock(iter, key); // will not unlock, to avoid 2 threads + // got the same slot. + } // else: the key is evicted from the bucket(full), try to reintert by + // eviction including reclaimed key. + break; + } + if (probe_res == ProbeResult::Empty) { + KeyType expected_key = Bucket::empty_key(); + + if (bucket.try_lock(iter, expected_key)) { + *bucket.digests(iter) = Bucket::key_to_digest(key); + atomicAdd(&bucket_sizes[bucket_id], 1); + result = InsertResult::Insert; + break; + } // else it was locked by another thread. + } + } + + while (result == InsertResult::Init) { + + KeyType evict_key; + ScoreType evict_score = + ScorePolicy::score_for_compare(policy_type, score); + + bool succeed = bucket.template reduce( + iter, evict_key, evict_score, sm_scores); + + if (succeed) { + + if (bucket.try_lock(iter, evict_key)) { + if (*bucket.scores(iter) != evict_score) { + // that means when reduce we got a new key but old score. + bucket.unlock(iter, evict_key); + } else { + *bucket.digests(iter) = Bucket::key_to_digest(key); + if (evict_key == Bucket::reclaimed_key()) { + atomicAdd(bucket_sizes + bucket_id, 1); + result = InsertResult::Reclaim; + } else { + *bucket.scores(iter) = ScoreType(); + result = InsertResult::Evict; + } + break; + } + } // else it was locked by another thread. + } else { + result = InsertResult::Busy; + break; + } + } + + IndexType index = -1; + KeyType *table_key_slot = nullptr; + if (isInsertSuccess(result)) { + ScorePolicy::update(policy_type, return_scores, bucket.scores(iter), + score); + index = bucket_id * bucket.capacity() + iter; + table_key_slot = bucket.keys(iter); + } + ScorePolicy::set(return_scores, scores, i, score); + //TODO: unlock using index. + table_key_slots[i] = table_key_slot; + if (indices) { + indices[i] = index; + } + if (insert_results) { + insert_results[i] = result; + } + } +} + +template +__global__ void table_insert_and_evict_kernel( + Table table, int *__restrict__ bucket_sizes, int64_t batch, + typename Table::KeyType const *__restrict__ input_keys, + InsertResult *__restrict__ insert_results, IndexType *__restrict__ indices, + ScoreType *__restrict__ scores, ScorePolicyType policy_type, + bool return_scores, typename Table::KeyType **__restrict__ table_key_slots, + CounterType *evicted_counter, + typename Table::KeyType *__restrict__ evicted_keys, + ScoreType *__restrict__ evicted_scores, + IndexType *__restrict__ evicted_indices) { + + using KeyType = typename Table::KeyType; + using Bucket = typename Table::BucketType; + using Iter = typename Bucket::Iterator; + using ProbeResult = typename Bucket::ProbeResult; + + static constexpr int BlockSize = KernelTraits::ThreadBlockDim; + static constexpr int BufferDim = KernelTraits::NumScorePerThread; + + static constexpr int ProbingGroupSize = KernelTraits::ProbingGroupSize; + static constexpr int ReductionGroupSize = KernelTraits::ReductionGroupSize; + + auto tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + + __shared__ ScoreType sm_scores[BlockSize * BufferDim]; + + for (int64_t i = tid; i < batch; i += gridDim.x * blockDim.x) { + + KeyType key = input_keys[i]; + ScoreType score = ScorePolicy::get(policy_type, scores, i); + + InsertResult result = InsertResult::Init; + + Bucket bucket; + KeyType hashcode = KeyType(); + int64_t bucket_id; + if (Bucket::is_valid(key)) { + hashcode = Table::hash(key); + uint64_t global_idx = static_cast(hashcode % table.capacity()); + bucket_id = global_idx / table.bucket_capacity(); + // bucket_id = (hashcode % table.capacity()) / table.bucket_capacity(); + bucket = table[bucket_id]; + } + Iter iter = Iter(hashcode % table.bucket_capacity()); + ProbeResult probe_res = ProbeResult::Init; + int step = 0; + while (step != bucket.capacity()) { + probe_res = bucket.probe(key, iter, step); + if (probe_res == ProbeResult::Existed) { + KeyType expected_key = key; + if (bucket.try_lock(iter, expected_key)) { + result = InsertResult::Assign; + // bucket.unlock(iter, key); // will not unlock, to avoid 2 threads + // got the same slot. + } // else: the key is evicted from the bucket(full), try to reintert by + // eviction including reclaimed key. + break; + } + if (probe_res == ProbeResult::Empty) { + KeyType expected_key = Bucket::empty_key(); + if (bucket.try_lock(iter, expected_key)) { + *bucket.digests(iter) = Bucket::key_to_digest(key); + atomicAdd(&bucket_sizes[bucket_id], 1); + result = InsertResult::Insert; + break; + } // else it was locked by another thread. + } + } + + KeyType evict_key; + ScoreType evict_score; + + while (result == InsertResult::Init) { + + evict_score = ScorePolicy::score_for_compare(policy_type, score); + bool succeed = bucket.template reduce( + iter, evict_key, evict_score, sm_scores); + + if (succeed) { + + if (bucket.try_lock(iter, evict_key)) { + if (*bucket.scores(iter) != evict_score) { + // that means when reduce we got a new key but old score. + bucket.unlock(iter, evict_key); + } else { + *bucket.digests(iter) = Bucket::key_to_digest(key); + if (evict_key == Bucket::reclaimed_key()) { + atomicAdd(&bucket_sizes[bucket_id], 1); + result = InsertResult::Reclaim; + } else { + *bucket.scores(iter) = ScoreType(); + result = InsertResult::Evict; + } + break; + } + } // else it was locked by another thread. + } else { + result = InsertResult::Busy; + evict_key = key; + evict_score = score; + break; + } + } + + auto g = cg::tiled_partition( + cg::this_thread_block()); + bool evicted = + (result == InsertResult::Evict or result == InsertResult::Busy) ? true + : false; + uint32_t vote = g.ballot(evicted); + int group_cnt = __popc(vote); + CounterType group_offset = 0; + if (g.thread_rank() == 0) { + group_offset = + atomicAdd(evicted_counter, static_cast(group_cnt)); + } + group_offset = g.shfl(group_offset, 0); + + int previous_cnt = group_cnt - __popc(vote >> g.thread_rank()); + int64_t out_id = group_offset + previous_cnt; + + if (evicted) { + evicted_keys[out_id] = evict_key; + if (evicted_scores) { + evicted_scores[out_id] = evict_score; + } + if (evicted_indices) { + IndexType index; + if (result == InsertResult::Evict) { + index = bucket_id * bucket.capacity() + iter; + } else { + index = -1; + } + evicted_indices[out_id] = index; + } + } + + IndexType index = -1; + KeyType *table_key_slot = nullptr; + if (isInsertSuccess(result)) { + ScorePolicy::update(policy_type, return_scores, bucket.scores(iter), + score); + index = bucket_id * bucket.capacity() + iter; + table_key_slot = bucket.keys(iter); + } + ScorePolicy::set(return_scores, scores, i, score); + table_key_slots[i] = table_key_slot; + if (indices) { + indices[i] = index; + } + if (insert_results) { + insert_results[i] = result; + } + } +} + +template +__global__ void +table_unlock_kernel(Table table, int64_t batch, + typename Table::KeyType const *__restrict__ input_keys, + typename Table::KeyType **__restrict__ table_key_slots) { + using KeyType = typename Table::KeyType; + + auto tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + + for (int64_t i = tid; i < batch; i += gridDim.x * blockDim.x) { + KeyType key = input_keys[i]; + KeyType *key_slot = table_key_slots[i]; + if (key_slot) { + *key_slot = key; + } + } +} + +template +__global__ void +table_erase_kernel(Table table, int *__restrict__ bucket_sizes, int64_t batch, + typename Table::KeyType const *__restrict__ input_keys, + IndexType *__restrict__ indices) { + + using KeyType = typename Table::KeyType; + using Bucket = typename Table::BucketType; + using Iter = typename Bucket::Iterator; + + auto tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + + for (int64_t i = tid; i < batch; i += gridDim.x * blockDim.x) { + + KeyType key = input_keys[i]; + + Bucket bucket; + KeyType hashcode = KeyType(); + int64_t bucket_id; + if (Bucket::is_valid(key)) { + hashcode = Table::hash(key); + uint64_t global_idx = static_cast(hashcode % table.capacity()); + bucket_id = global_idx / table.bucket_capacity(); + // bucket_id = (hashcode % table.capacity()) / table.bucket_capacity(); + bucket = table[bucket_id]; + } + Iter iter = Iter(hashcode % table.bucket_capacity()); + int step = 0; + auto probe_res = bucket.probe(key, iter, step); + bool found = probe_res == Bucket::ProbeResult::Existed; + IndexType index = -1; + if (found) { + + KeyType expected_key = key; + if (bucket.try_lock(iter, expected_key)) { + *bucket.scores(iter) = ScoreType(); + *bucket.digests(iter) = Bucket::empty_digest(); + + bucket.unlock(iter, Bucket::reclaimed_key()); + atomicSub(bucket_sizes + bucket_id, 1); + } else { + found = false; // only one update will succeed for duplicated keys. + } + + if (found) { + index = bucket_id * bucket.capacity() + iter; + } + } + if (indices) { + indices[i] = index; + } + } +} + +template +__global__ void +table_export_batch_kernel(Table table, IndexType begin, IndexType end, + CounterType *__restrict__ counter, + typename Table::KeyType *__restrict__ keys, + ScoreType *__restrict__ scores, + IndexType *__restrict__ indices) { + using KeyType = typename Table::KeyType; + using Bucket = typename Table::BucketType; + using Iter = typename Bucket::Iterator; + + auto g = cg::tiled_partition(cg::this_thread_block()); + + auto tid = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + + for (int64_t i = begin + tid; i < end; i += gridDim.x * blockDim.x) { + + int64_t bucket_id = i / table.bucket_capacity(); + + Bucket bucket = table[bucket_id]; + + Iter iter = Iter(i % bucket.capacity()); + + const KeyType key = *bucket.keys(iter); + const ScoreType score = *bucket.scores(iter); + const IndexType index = i; + + bool valid = Bucket::is_valid(key); + uint32_t vote = g.ballot(valid); + int group_cnt = __popc(vote); + CounterType group_offset = 0; + if (g.thread_rank() == 0) { + group_offset = atomicAdd(counter, static_cast(group_cnt)); + } + group_offset = g.shfl(group_offset, 0); + + int previous_cnt = group_cnt - __popc(vote >> g.thread_rank()); + int64_t out_id = group_offset + previous_cnt; + + if (valid) { + keys[out_id] = key; + if (scores) { + scores[out_id] = score; + } + if (indices) { + indices[out_id] = index; + } + } + } +} + +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/table_operation/lookup.cu b/corelib/dynamicemb/src/table_operation/lookup.cu new file mode 100644 index 000000000..3c3b20055 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/lookup.cu @@ -0,0 +1,88 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "kernels.cuh" +#include "table.cuh" + +namespace dyn_emb { + +void table_lookup_single_score(at::Tensor table_storage, + std::vector dtypes, + int64_t bucket_capacity, at::Tensor keys, + std::vector> scores, + std::vector policy_types, + std::vector is_returns, at::Tensor founds, + std::optional indices) { + + auto key_type = get_data_type(keys); + + bool is_return = is_returns[0]; + ScorePolicyType policy_type = policy_types[0]; + auto scores_ = get_pointer(scores[0]); + auto indices_ = get_pointer(indices); + auto founds_ = founds.data_ptr(); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + int64_t num_total = keys.size(0); + + constexpr int BLOCK_SIZE = 256; + + DISPATCH_KEY_TYPE(key_type, KeyType, [&] { + auto keys_ = get_pointer(keys); + + constexpr int64_t total_size = + sizeof(KeyType) + sizeof(DigestType) + sizeof(ScoreType); + int64_t bucket_bytes = bucket_capacity * total_size; + int64_t num_buckets = + table_storage.numel() * table_storage.element_size() / bucket_bytes; + + using Bucket = LinearBucket; + using Table = LinearBucketTable; + + auto table = Table(reinterpret_cast(table_storage.data_ptr()), + num_buckets, bucket_capacity); + + table_lookup_kernel + <<<(num_total + BLOCK_SIZE - 1) / BLOCK_SIZE, BLOCK_SIZE, 0, stream>>>( + table, num_total, keys_, founds_, indices_, scores_, policy_type, + is_return); + }); + DEMB_CUDA_KERNEL_LAUNCH_CHECK(); +} + +void table_lookup(at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor keys, + std::vector> scores, + std::vector policy_types, + std::vector is_returns, at::Tensor founds, + std::optional indices) { + + int64_t num_total = keys.size(0); + if (num_total == 0) + return; + + if (scores.size() == 1) { + table_lookup_single_score(table_storage, dtypes, bucket_capacity, keys, + scores, policy_types, is_returns, founds, + indices); + } else { + throw std::runtime_error("Not support multi-scores."); + } +} + +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/table_operation/score.cuh b/corelib/dynamicemb/src/table_operation/score.cuh new file mode 100644 index 000000000..7bb2e8970 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/score.cuh @@ -0,0 +1,87 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +namespace dyn_emb { + +using ScoreType = uint64_t; + +enum class ScorePolicyType : uint8_t { + Const = 0, + Assign = 1, + Accumulate = 2, + GlobalTimer = 3, +}; + +struct ScorePolicy { + + static __device__ __forceinline__ ScoreType get(ScorePolicyType policy_type, + ScoreType *scores, + int64_t index) { + + if (policy_type == ScorePolicyType::Const) { + return ScoreType(); + } + if (policy_type == ScorePolicyType::GlobalTimer) { + ScoreType score; + asm volatile("mov.u64 %0,%%globaltimer;" : "=l"(score)); + return score; + } else { + return scores[index]; + } + } + + static __device__ __forceinline__ ScoreType + score_for_compare(ScorePolicyType policy_type, ScoreType score) { + return UINT64_MAX; + } + + static __device__ __forceinline__ void update(ScorePolicyType policy_type, + bool is_return, + ScoreType *table_score, + ScoreType &score) { + + if (policy_type == ScorePolicyType::Const) { + if (is_return) { + score = *table_score; + } + return; + } + if (policy_type == ScorePolicyType::Accumulate) { + score += *table_score; + *table_score = score; + } else { + *table_score = score; + } + } + + static __device__ __forceinline__ void set(bool is_return, ScoreType *scores, + int64_t index, ScoreType score) { + if (is_return) { + scores[index] = score; + } + } +}; + +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/table_operation/table.cu b/corelib/dynamicemb/src/table_operation/table.cu new file mode 100644 index 000000000..3a29c3d98 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/table.cu @@ -0,0 +1,144 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "table.cuh" + +namespace dyn_emb { + +std::vector table_partition(at::Tensor storage, + std::vector dtypes, + int64_t bucket_capacity, + int64_t num_buckets) { + + int64_t num_types = dtypes.size(); + std::vector result; + result.reserve(num_types); + + std::vector bytes_offset; + std::vector bucket_bytes_offset; + + bytes_offset.reserve(num_types + 1); + bucket_bytes_offset.reserve(num_types + 1); + + bytes_offset.push_back(0); + bucket_bytes_offset.push_back(0); + + for (int64_t i = 0; i < num_types; i++) { + auto scalar_type = static_cast(dtypes[i]); + auto dtype_bytes = get_size(scalar_type); + auto offset = bytes_offset.back() + dtype_bytes; + bytes_offset.push_back(offset); + + auto array_bytes = dtype_bytes * bucket_capacity; + auto array_offset = bucket_bytes_offset.back() + array_bytes; + bucket_bytes_offset.push_back(array_offset); + } + + int64_t bucket_bytes = bucket_bytes_offset.back(); + if (bucket_bytes * num_buckets != storage.numel() * storage.element_size()) { + throw std::runtime_error( + "Storage size mismatched with bucket_bytes * num_buckets"); + } + + for (int64_t i = 0; i < num_types; i++) { + int64_t stride = bucket_bytes / (bytes_offset[i + 1] - bytes_offset[i]); + void *raw_data = storage.data_ptr() + bucket_capacity * bytes_offset[i]; + result.push_back(at::from_blob(raw_data, {num_buckets, bucket_capacity}, + {stride, 1}, + storage.options().dtype(dtypes[i]))); + } + return result; +} + +std::vector tensor_partition(at::Tensor input, + std::vector byte_range, + std::vector dtypes) { + int num_partition = byte_range.size() - 1; + std::vector result; + result.reserve(num_partition); + for (int i = 0; i < num_partition; i++) { + auto raw_data = input.data_ptr() + byte_range[i]; + int64_t partition_size = byte_range[i + 1] - byte_range[i]; + auto scalar_type = static_cast(dtypes[i]); + partition_size = partition_size / get_size(scalar_type); + result.push_back(at::from_blob(raw_data, {partition_size}, + input.options().dtype(dtypes[i]))); + } + return result; +} + +} // namespace dyn_emb + +namespace py = pybind11; + +void bind_table_operation(py::module &m) { + + m.def("tensor_partition", &dyn_emb::tensor_partition, + "split the tensor into several sub-partitions.", py::arg("input"), + py::arg("byte_range"), py::arg("dtypes")); + + m.def("table_partition", &dyn_emb::table_partition, + "split the tensor into several sub-partitions.", py::arg("storage"), + py::arg("dtypes"), py::arg("bucket_capacity"), py::arg("num_buckets")); + + m.def("table_lookup", &dyn_emb::table_lookup, "lookup the table", + py::arg("table_storage"), py::arg("dtypes"), py::arg("bucket_capacity"), + py::arg("keys"), py::arg("scores"), py::arg("policy_types"), + py::arg("is_returns"), py::arg("founds"), py::arg("indices")); + + m.def("table_insert", &dyn_emb::table_insert, "insert into the table", + py::arg("table_storage"), py::arg("dtypes"), py::arg("bucket_capacity"), + py::arg("bucket_sizes"), py::arg("keys"), py::arg("scores"), + py::arg("policy_types"), py::arg("is_returns"), py::arg("indices"), + py::arg("insert_results")); + + m.def("table_insert_and_evict", &dyn_emb::table_insert_and_evict, + "insert into the table", py::arg("table_storage"), py::arg("dtypes"), + py::arg("bucket_capacity"), py::arg("bucket_sizes"), py::arg("keys"), + py::arg("scores"), py::arg("policy_types"), py::arg("is_returns"), + py::arg("insert_results"), py::arg("indices"), py::arg("num_evicted"), + py::arg("evicted_keys"), py::arg("evicted_indices"), + py::arg("evicted_scores")); + + m.def("table_erase", &dyn_emb::table_erase, "erase keys from the table", + py::arg("table_storage"), py::arg("dtypes"), py::arg("bucket_capacity"), + py::arg("bucket_sizes"), py::arg("keys"), + py::arg("indices") = py::none()); + + m.def("table_export_batch", &dyn_emb::table_export_batch, + "erase items[offset, offset + batch) from the table", + py::arg("table_storage"), py::arg("dtypes"), py::arg("bucket_capacity"), + py::arg("batch"), py::arg("offset"), py::arg("counter"), + py::arg("keys"), py::arg("scores"), py::arg("indices") = py::none()); + + py::enum_(m, "ScorePolicy") + .value("CONST", dyn_emb::ScorePolicyType::Const) + .value("ASSIGN", dyn_emb::ScorePolicyType::Assign) + .value("ACCUMULATE", dyn_emb::ScorePolicyType::Accumulate) + .value("GLOBAL_TIMER", dyn_emb::ScorePolicyType::GlobalTimer) + .export_values(); + + py::enum_(m, "InsertResult") + .value("INSERT", dyn_emb::InsertResult::Insert) + .value("RECLAIM", dyn_emb::InsertResult::Reclaim) + .value("ASSIGN", dyn_emb::InsertResult::Assign) + .value("EVICT", dyn_emb::InsertResult::Evict) + .value("DUPLICATED", dyn_emb::InsertResult::Duplicated) + .value("BUSY", dyn_emb::InsertResult::Busy) + .value("INIT", dyn_emb::InsertResult::Init) + .export_values(); +} diff --git a/corelib/dynamicemb/src/table_operation/table.cuh b/corelib/dynamicemb/src/table_operation/table.cuh new file mode 100644 index 000000000..0bcdf7c51 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/table.cuh @@ -0,0 +1,140 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#include "../check.h" +#include "torch_utils.h" +#include "types.cuh" + +#include +#include +#include + +#include + +#include +#include + +#include +#include + +#define DISPATCH_KEY_TYPE(DATA_TYPE, HINT, ...) \ + switch (DATA_TYPE) { \ + CASE_TYPE_USING_HINT(DataType::Int64, int64_t, HINT, __VA_ARGS__) \ + CASE_TYPE_USING_HINT(DataType::UInt64, uint64_t, HINT, __VA_ARGS__) \ + default: \ + throw std::runtime_error("Not supported key type."); \ + } + +#define DISPATCH_SCORE_TYPE(DATA_TYPE, HINT, ...) \ + switch (DATA_TYPE) { \ + CASE_TYPE_USING_HINT(DataType::UInt64, uint64_t, HINT, __VA_ARGS__) \ + CASE_TYPE_USING_HINT(DataType::UInt32, uint32_t, HINT, __VA_ARGS__) \ + default: \ + throw std::runtime_error("Not supported score type."); \ + } + +#define DISPATCH_SCORE_POLICY(SCORE_POLICY, HINT, ...) \ + switch (SCORE_POLICY) { \ + CASE_ENUM_USING_HINT(ScorePolicyType::Const, HINT, __VA_ARGS__) \ + CASE_ENUM_USING_HINT(ScorePolicyType::Assign, HINT, __VA_ARGS__) \ + CASE_ENUM_USING_HINT(ScorePolicyType::Accumulate, HINT, __VA_ARGS__) \ + CASE_ENUM_USING_HINT(ScorePolicyType::GlobalTimer, HINT, __VA_ARGS__) \ + default: \ + throw std::runtime_error("Not supported score policy."); \ + } + +namespace dyn_emb { + +inline int get_size(torch::ScalarType scalar_type) { + switch (scalar_type) { + case torch::kUInt8: + return 1; + case torch::kInt8: + return 1; + case torch::kInt16: + return 2; + case torch::kInt32: + return 4; + case torch::kInt64: + return 8; + case torch::kFloat32: + return 4; + case torch::kFloat64: + return 8; + case torch::kBool: + return 1; + case torch::kBFloat16: + return 2; + case torch::kFloat16: + return 2; + case torch::kUInt16: + return 2; + case torch::kUInt32: + return 4; + case torch::kUInt64: + return 8; + default: + throw std::runtime_error("Unsupported scalar type."); + } +} + +void table_lookup(at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor keys, + std::vector> scores, + std::vector policy_types, + std::vector is_returns, at::Tensor founds, + std::optional indices); + +void table_insert(at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, + at::Tensor keys, + std::vector> scores, + std::vector policy_types, + std::vector is_returns, + std::optional indices, + std::optional insert_results); + +void table_insert_and_evict( + at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, at::Tensor keys, + std::vector> scores, + std::vector policy_types, std::vector is_returns, + std::optional insert_results, std::optional indices, + at::Tensor num_evicted, at::Tensor evicted_keys, at::Tensor evicted_indices, + std::vector evicted_scores); + +void table_erase(at::Tensor table_storage, std::vector dtypes, + int64_t bucket_capacity, at::Tensor bucket_sizes, + at::Tensor keys, std::optional indices); + +void table_export_batch(at::Tensor table_storage, + std::vector dtypes, + int64_t bucket_capacity, int64_t batch, int64_t offset, + at::Tensor counter, at::Tensor keys, + std::vector> scores, + std::optional indices); + +std::vector table_partition(at::Tensor storage, + std::vector dtypes, + int64_t bucket_capacity, + int64_t num_buckets); + +std::vector tensor_partition(at::Tensor input, + std::vector byte_range, + std::vector dtypes); + +} // namespace dyn_emb diff --git a/corelib/dynamicemb/src/table_operation/types.cuh b/corelib/dynamicemb/src/table_operation/types.cuh new file mode 100644 index 000000000..df5fa0936 --- /dev/null +++ b/corelib/dynamicemb/src/table_operation/types.cuh @@ -0,0 +1,484 @@ +/****************************************************************************** +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. +All rights reserved. # SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "score.cuh" + +extern "C" __device__ size_t __cvta_generic_to_shared(const void *); + +namespace dyn_emb { + +using CounterType = int64_t; +using DigestType = uint8_t; +using IndexType = int64_t; + +__forceinline__ __device__ int atomicAdd(int *address, int val) { + return ::atomicAdd(address, val); +} + +__device__ __forceinline__ CounterType atomicAdd(CounterType *address, + CounterType val) { + return (CounterType)::atomicAdd((unsigned long long *)address, + (unsigned long long)val); +} + +enum class InsertResult : uint8_t { + Insert, // Insert into an empty slot. + Reclaim, // Insert into a reclaimed slot. + Assign, // Hit and assign. + Evict, // Evict a key and insert into the evicted slot. + Duplicated, // Meet duplicated keys on the fly. + Busy, // Insert failed as all slots busy. + Init, +}; + +__device__ __forceinline__ bool isInsertSuccess(InsertResult result) { + if (static_cast(result) <= + static_cast(InsertResult::Evict)) { + return true; + } + return false; +} + +// Select from double buffer. +// If i % 2 == 0, select buffer 0, else buffer 1. +__forceinline__ __device__ int same_buf(int i) { return (i & 0x01) ^ 0; } +// If i % 2 == 0, select buffer 1, else buffer 0. +__forceinline__ __device__ int diff_buf(int i) { return (i & 0x01) ^ 1; } + +template > +__forceinline__ __device__ void async_copy_bulk(T *dst, T const *src) { + static_assert(N % Stride == 0); + // dst = (ScoreType*)__cvta_generic_to_shared((void*)dst); +#pragma unroll + for (int i = 0; i < N; i += Stride) { + __pipeline_memcpy_async(dst + i, src + i, sizeof(T) * Stride); + } +} + +template && + sizeof(KeyType_) == 8>> +struct LinearBucket { + + __forceinline__ __device__ LinearBucket(uint8_t *storage, uint32_t capacity) + : storage_(storage), capacity_(capacity) {} + + __forceinline__ __device__ LinearBucket() : LinearBucket(nullptr, 0) {} + + /* + Iterator: + */ + using Iterator = uint32_t; + + template + static __forceinline__ __device__ int align(Iterator &iter) { + // iter - (iter % AlignSize) + constexpr uint32_t MASK = 0xffffffffU - (AlignSize - 1); + return iter & MASK; + } + + /* + Keys: + */ + using KeyType = KeyType_; + using AtomicKey = cuda::atomic; + + static constexpr uint64_t EmptyKey = UINT64_C(0xFFFFFFFFFFFFFFFF); + static constexpr uint64_t LockedKey = UINT64_C(0xFFFFFFFFFFFFFFFD); + static constexpr uint64_t ReclaimKey = UINT64_C(0xFFFFFFFFFFFFFFFE); + + static constexpr uint64_t ReserveKeyMask = UINT64_C(0xFFFFFFFFFFFFFFFC); + + static __device__ __forceinline__ uint64_t hash(uint64_t key) { + uint64_t k = key; + k ^= k >> 33; + k *= UINT64_C(0xff51afd7ed558ccd); + k ^= k >> 33; + k *= UINT64_C(0xc4ceb9fe1a85ec53); + k ^= k >> 33; + return static_cast(k); + } + + static __device__ __forceinline__ KeyType empty_key() { return EmptyKey; } + + static __device__ __forceinline__ KeyType reclaimed_key() { + return ReclaimKey; + } + + static __device__ __forceinline__ DigestType empty_digest() { + auto hashcode = hash(EmptyKey); + return hashcode_to_digest(hashcode); + } + + static __device__ __forceinline__ bool is_valid(uint64_t const &key) { + return (key & ReserveKeyMask) != ReserveKeyMask; + } + + __device__ __forceinline__ bool is_empty(Iterator &iter) const { + auto key_slot = reinterpret_cast(keys(iter)); + auto slot_key = key_slot->load(cuda::std::memory_order_relaxed); + return slot_key == EmptyKey; + } + + __device__ __forceinline__ bool is_locked(Iterator &iter) const { + auto key_slot = reinterpret_cast(keys(iter)); + auto slot_key = key_slot->load(cuda::std::memory_order_relaxed); + return slot_key == LockedKey; + } + + __device__ __forceinline__ bool try_lock(Iterator &iter, KeyType &key) { + auto key_slot = reinterpret_cast(keys(iter)); + return key_slot->compare_exchange_strong( + key, static_cast(LockedKey), cuda::std::memory_order_acquire, + cuda::std::memory_order_relaxed); + } + + __device__ __forceinline__ void unlock(Iterator &iter, KeyType key) { + auto key_slot = reinterpret_cast(keys(iter)); + key_slot->store(key, cuda::std::memory_order_release); + } + + /* + Digest: + */ + using DigestVector = uint32_t; // used for comparison + using DigestBuffer = uint4; // used for loading + using ComparedResult = int; + + static constexpr int VectorDim = sizeof(DigestVector) / sizeof(DigestType); + static constexpr int BufferDim = sizeof(DigestBuffer) / sizeof(DigestType); + static constexpr int NumVectorPerBuffer = + sizeof(DigestBuffer) / sizeof(DigestVector); + + struct VectorComparator { + static __device__ __forceinline__ ComparedResult compare(DigestVector lhs, + DigestVector rhs) { + // Perform a vectorized comparison by byte, + // and if they are equal, set the corresponding byte in the result to + // 0xff. + ComparedResult cmp_result = __vcmpeq4(lhs, rhs); + cmp_result &= 0x01010101; + return cmp_result; + } + + static __device__ __forceinline__ int + equal_index(ComparedResult &cmp_result) { + if (cmp_result == 0) + return -1; + // CUDA uses little endian, + // and the lowest byte in register stores in the lowest address. + uint32_t index = (__ffs(cmp_result) - 1) >> 3; + cmp_result &= (cmp_result - 1); + return index; + } + }; + + static __device__ __forceinline__ DigestType + hashcode_to_digest(uint64_t hashcode) { + return static_cast(hashcode >> 32); + } + + static __device__ __forceinline__ DigestType key_to_digest(KeyType key) { + auto hashcode = hash(key); + return hashcode_to_digest(hashcode); + } + + static __device__ __forceinline__ DigestVector + digest_to_vector(DigestType digest) { + return static_cast(__byte_perm(digest, digest, 0x0000)); + } + + static __device__ __forceinline__ void + digest_buffer_to_vector(DigestBuffer const &digest_buffer, + DigestVector digest_vec[NumVectorPerBuffer]) { + digest_vec[0] = digest_buffer.x; + digest_vec[1] = digest_buffer.y; + digest_vec[2] = digest_buffer.z; + digest_vec[3] = digest_buffer.w; + } + + /* + Scores: + */ + static constexpr uint64_t EmptyScore = UINT64_C(0); + static constexpr uint64_t MaxScore = UINT64_C(0xFFFFFFFFFFFFFFFF); + using ScoreVector = uint4; + static constexpr int NumScorePerVector = + sizeof(ScoreVector) / sizeof(ScoreType); + + /* + */ + static constexpr int KeyOffset = 0; + static constexpr int DigestOffset = KeyOffset + sizeof(KeyType); + static constexpr int ScoreOffset = DigestOffset + sizeof(DigestType); + static constexpr int BucketBytes = ScoreOffset + sizeof(ScoreType); + + static __device__ __forceinline__ uint64_t memory_usage(int size) { + return BucketBytes * size; + } + + __forceinline__ __device__ uint32_t capacity() const { return capacity_; } + + __forceinline__ __device__ KeyType *keys(const Iterator &iter) const { + return reinterpret_cast(storage_ + KeyOffset * capacity_) + iter; + } + + __forceinline__ __device__ DigestType *digests(const Iterator &iter) const { + return reinterpret_cast(storage_ + DigestOffset * capacity_) + + iter; + } + + __forceinline__ __device__ ScoreType *scores(const Iterator &iter) const { + return reinterpret_cast(storage_ + ScoreOffset * capacity_) + + iter; + } + + enum class ProbeResult : uint8_t { + Init = 0, + Existed = 1, + Empty = 2, + Exhausted = 3, + Failed = 4, + Absent = 5, + }; + /* + Let iter and step have a state, and if they have been probed, they will not be + probed again + */ + template + __forceinline__ __device__ ProbeResult probe(KeyType key, Iterator &iter, + int &step) const { + static_assert(GroupSize == 1); + if (not storage_) { + step = capacity_; + return ProbeResult::Failed; + } + + if (step == capacity_) { + return ProbeResult::Exhausted; + } + + auto hashcode = hash(key); + auto digest = hashcode_to_digest(hashcode); + auto digest_vec = digest_to_vector(digest); + + // bool early_stop = false; // used when GroupSize > 1 + if (storage_ == nullptr or capacity_ == 0) { + // early_stop = true; + return ProbeResult::Failed; + } + + if (iter < 0 or iter > capacity_) { + iter = hashcode % capacity_; + } + + constexpr int Stride = BufferDim; + iter = align(iter); + + auto empty_digest = key_to_digest(EmptyKey); + auto empty_vec = digest_to_vector(empty_digest); + + ProbeResult result = ProbeResult::Init; + + for (; step < capacity_; step += Stride) { + + auto buffer = *(reinterpret_cast(digests(iter))); + // DigestBuffer buffer; + + constexpr int Length = NumVectorPerBuffer; + + DigestVector vec[Length] = {buffer.x, buffer.y, buffer.z, buffer.w}; + // digest_buffer_to_vector(buffer, vec); + + // vec[0] = buffer.x; + // vec[1] = buffer.y; + // vec[2] = buffer.z; + // vec[3] = buffer.w; + + for (int i = 0; i < Length; i++) { + + int cmp_res = VectorComparator::compare(vec[i], digest_vec); + while (true) { + int offset = VectorComparator::equal_index(cmp_res); + if (offset < 0) + break; + + auto possible_iter = iter + i * VectorDim + offset; + + auto possible_key_slot = + reinterpret_cast(keys(possible_iter)); + + auto possible_key = + possible_key_slot->load(cuda::std::memory_order_relaxed); + + if (possible_key == key) { + iter = possible_iter; + return ProbeResult::Existed; + } + } + cmp_res = VectorComparator::compare(vec[i], empty_vec); + while (true) { + int offset = VectorComparator::equal_index(cmp_res); + if (offset < 0) + break; + + auto possible_iter = iter + i * VectorDim + offset; + + auto possible_key_slot = + reinterpret_cast(keys(possible_iter)); + + auto possible_key = + possible_key_slot->load(cuda::std::memory_order_relaxed); + + if (possible_key == EmptyKey) { + iter = possible_iter; + return ProbeResult::Empty; + } + } + } + iter = (iter + Stride) % capacity_; + } + return ProbeResult::Exhausted; + } + + template + __forceinline__ __device__ bool reduce(Iterator &dst_iter, KeyType &dst_key, + ScoreType &dst_score, + ScoreType *sm_buffers) const { + + static_assert(GroupSize == 1); + + static constexpr int BulkDim = BufferDim / 2; + static_assert(BulkDim == 4); + + static constexpr int Stride = NumScorePerVector; + + Iterator iter = 0; + int rank = threadIdx.x; + + // ScoreType* sm_buffers = + // (ScoreType*)__cvta_generic_to_shared(sm_buffers_); sm_buffers = + // reinterpret_cast(__cvta_generic_to_shared((void*)sm_buffers)); + + // asm("cvta.to.shared.u64 %0, %1;" : "=l"(sm_buffers) : "l"(sm_buffers)); + + async_copy_bulk(&sm_buffers[rank * BufferDim], + scores(iter)); + __pipeline_commit(); + + bool succeed = false; + + for (; iter < capacity_; iter += BulkDim) { + if (iter < capacity_ - BulkDim) { + async_copy_bulk( + &sm_buffers[rank * BufferDim] + diff_buf(iter / BulkDim) * BulkDim, + scores(iter) + BulkDim); + } + __pipeline_commit(); + __pipeline_wait_prior(1); + ScoreType temp_scores[Stride]; + ScoreType *src = + sm_buffers + rank * BufferDim + same_buf(iter / BulkDim) * BulkDim; +#pragma unroll + for (int k = 0; k < BulkDim; k += Stride) { + *reinterpret_cast(temp_scores) = + *reinterpret_cast(src + k); +#pragma unroll + for (int j = 0; j < Stride; j += 1) { + ScoreType temp_score = temp_scores[j]; + if (temp_score < dst_score) { + auto temp_key_slot = + reinterpret_cast(keys(iter + k + j)); + + auto temp_key = + temp_key_slot->load(cuda::std::memory_order_relaxed); + + if (temp_key != LockedKey && temp_key != EmptyKey) { + dst_iter = iter + k + j; + dst_key = temp_key; + dst_score = temp_score; + succeed = true; + } + } + } + } + } + return succeed; + } + + uint8_t *__restrict__ storage_; + uint32_t capacity_; +}; + +template struct LinearBucketTable { + using BucketType = BucketType_; + using KeyType = typename BucketType::KeyType; + + LinearBucketTable(uint8_t *storage, uint64_t num_buckets, + uint32_t bucket_capacity) + : storage_(storage), num_buckets_(num_buckets), + bucket_capacity_(bucket_capacity) {} + + static __device__ __forceinline__ uint64_t hash(uint64_t key) { + return BucketType::hash(key); + } + + __device__ __forceinline__ BucketType operator[](uint64_t idx) const { + // assert(idx < num_buckets_); + auto bucket_raw_data = + storage_ + BucketType::memory_usage(bucket_capacity_) * idx; + return BucketType(bucket_raw_data, bucket_capacity_); + } + + __device__ __forceinline__ uint64_t capacity() const { + return num_buckets_ * bucket_capacity_; + } + + __device__ __forceinline__ uint32_t bucket_capacity() const { + return bucket_capacity_; + } + + __device__ __forceinline__ BucketType get_bucket(KeyType key) const { + auto hashcode = hash(key); + auto idx = hashcode / bucket_capacity_; + auto bucket_raw_data = + storage_ + BucketType::memory_usage(bucket_capacity_) * idx; + return BucketType(bucket_raw_data, bucket_capacity_); + } + + uint8_t *__restrict__ storage_; + uint64_t num_buckets_; + uint32_t bucket_capacity_; +}; + +} // namespace dyn_emb \ No newline at end of file diff --git a/corelib/dynamicemb/src/torch_utils.h b/corelib/dynamicemb/src/torch_utils.h index 09cad0132..86cb23d14 100644 --- a/corelib/dynamicemb/src/torch_utils.h +++ b/corelib/dynamicemb/src/torch_utils.h @@ -45,6 +45,28 @@ at::ScalarType convertTypeMetaToScalarType(const caffe2::TypeMeta& typeMeta); uint64_t device_timestamp(); +inline DataType get_data_type(at::Tensor tensor) { + return scalartype_to_datatype(tensor.dtype().toScalarType()); +} + +template T *get_pointer(at::Tensor tensor) { + if (not tensor.defined()) { + throw std::invalid_argument("Tensor is undefined."); + } + return static_cast(tensor.data_ptr()); +} + +template T *get_pointer(const std::optional &tensor) { + if (not tensor.has_value()) { + return nullptr; + } + auto value = tensor.value(); + if (not value.defined()) { + throw std::invalid_argument("Tensor is undefined."); + } + return static_cast(value.data_ptr()); +} + } // namespace dyn_emb //PYTHON WRAP diff --git a/corelib/dynamicemb/src/unique_op.cu b/corelib/dynamicemb/src/unique_op.cu index 559eb552a..45f04d82c 100644 --- a/corelib/dynamicemb/src/unique_op.cu +++ b/corelib/dynamicemb/src/unique_op.cu @@ -63,21 +63,25 @@ void bind_unique_op(py::module &m) { [](dyn_emb::UniqueOpBase &self, const at::Tensor &d_key, uint64_t len, const at::Tensor &d_output_index, const at::Tensor &d_unique_key, const at::Tensor &d_output_counter, uint64_t stream = 0, - const c10::optional &offset = c10::nullopt) { + const c10::optional &offset = c10::nullopt, + const c10::optional &d_frequency_counters = c10::nullopt, + const c10::optional &d_input_frequencies = c10::nullopt) { cudaStream_t cuda_stream = reinterpret_cast(stream); - if (offset.has_value()) { - self.unique(d_key, len, d_output_index, d_unique_key, - d_output_counter, cuda_stream, offset.value()); - } else { + at::Tensor offset_tensor = offset.has_value() ? offset.value() : at::Tensor(); + at::Tensor frequency_counters_tensor = d_frequency_counters.has_value() ? d_frequency_counters.value() : at::Tensor(); + at::Tensor input_frequencies_tensor = d_input_frequencies.has_value() ? d_input_frequencies.value() : at::Tensor(); + self.unique(d_key, len, d_output_index, d_unique_key, - d_output_counter, cuda_stream); - } + d_output_counter, cuda_stream, offset_tensor, + frequency_counters_tensor, input_frequencies_tensor); }, "Unique operation.", py::arg("d_key"), py::arg("len"), py::arg("d_output_index"), py::arg("d_unique_key"), py::arg("d_output_counter"), py::arg("stream") = 0, - py::arg("offset") = c10::nullopt) + py::arg("offset") = c10::nullopt, + py::arg("d_frequency_counters") = c10::nullopt, + py::arg("d_input_frequencies") = c10::nullopt) .def( "reset_capacity", diff --git a/corelib/dynamicemb/src/unique_op.h b/corelib/dynamicemb/src/unique_op.h index f465b347a..b6273ae32 100644 --- a/corelib/dynamicemb/src/unique_op.h +++ b/corelib/dynamicemb/src/unique_op.h @@ -40,7 +40,9 @@ class UniqueOpBase { virtual void unique(const at::Tensor d_key, const uint64_t len, at::Tensor d_output_index, at::Tensor d_unique_key, at::Tensor d_output_counter, cudaStream_t stream = 0, - at::Tensor offset = at::Tensor()) = 0; + at::Tensor offset = at::Tensor(), + at::Tensor d_frequency_counters = at::Tensor(), + at::Tensor d_input_frequencies = at::Tensor()) = 0; virtual void reset_capacity(at::Tensor keys, at::Tensor vals, const size_t capacity, @@ -66,8 +68,9 @@ class HashUniqueOp : public UniqueOpBase { void unique(const at::Tensor d_key, const uint64_t len, at::Tensor d_output_index, at::Tensor d_unique_key, at::Tensor d_output_counter, cudaStream_t stream = 0, - at::Tensor offset = - at::Tensor()) override { /// TODO: dtype check in runtime. + at::Tensor offset = at::Tensor(), + at::Tensor d_frequency_counters = at::Tensor(), + at::Tensor d_input_frequencies = at::Tensor()) override { /// TODO: dtype check in runtime. if (stream == 0) { stream = at::cuda::getCurrentCUDAStream().stream(); } @@ -82,12 +85,32 @@ class HashUniqueOp : public UniqueOpBase { offset_ptr = offset.data_ptr(); } + CounterType *frequency_counters_ptr = nullptr; + if (d_frequency_counters.defined() && d_frequency_counters.numel() > 0) { + // Check if frequency counters is of the same type as CounterType + if (d_frequency_counters.scalar_type() != at::CppTypeToScalarType::value) { + throw std::runtime_error( + "Frequency counters tensor must have the same type as CounterType."); + } + frequency_counters_ptr = d_frequency_counters.data_ptr(); + } + + const CounterType *input_frequencies_ptr = nullptr; + if (d_input_frequencies.defined() && d_input_frequencies.numel() > 0) { + // Check if input frequencies is of the same type as CounterType + if (d_input_frequencies.scalar_type() != at::CppTypeToScalarType::value) { + throw std::runtime_error( + "Input frequencies tensor must have the same type as CounterType."); + } + input_frequencies_ptr = d_input_frequencies.data_ptr(); + } + this->unique_op_->unique( reinterpret_cast(d_key.data_ptr()), len, reinterpret_cast(d_output_index.data_ptr()), reinterpret_cast(d_unique_key.data_ptr()), reinterpret_cast(d_output_counter.data_ptr()), stream, - offset_ptr); + offset_ptr, frequency_counters_ptr, input_frequencies_ptr); this->unique_op_->clear(stream); } diff --git a/corelib/dynamicemb/src/unique_variable.cu b/corelib/dynamicemb/src/unique_variable.cu index f1616dd03..2fbdbd10a 100644 --- a/corelib/dynamicemb/src/unique_variable.cu +++ b/corelib/dynamicemb/src/unique_variable.cu @@ -125,13 +125,21 @@ __global__ void get_insert_kernel( const KeyType *d_key, KeyType *d_unique_key, CounterType *d_val, const size_t len, KeyType *keys, CounterType *vals, const size_t capacity, CounterType *d_global_counter, const KeyType empty_key, - const CounterType empty_val, CounterType *offset_ptr = nullptr) { + const CounterType empty_val, + CounterType *d_frequency_counters, + const CounterType *d_input_frequencies, + CounterType *offset_ptr = nullptr) { const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < len) { CounterType offset = 0; if (offset_ptr != nullptr) { offset = offset_ptr[0]; } + + // if d_input_frequencies is nullptr, set input_freq to 1 + CounterType input_freq = (d_input_frequencies != nullptr) ? + d_input_frequencies[idx] : static_cast(1); + KeyType target_key = d_key[idx]; size_t hash_index = hasher::hash(target_key) % capacity; size_t counter = 0; @@ -155,17 +163,31 @@ __global__ void get_insert_kernel( d_unique_key[result_val] = target_key; d_val[idx] = result_val + offset; target_val_pos = result_val; + + if (d_frequency_counters != nullptr) { + atomicCAS(&d_frequency_counters[result_val], 0, input_freq); + } break; } else if (target_key == old_key) { while (target_val_pos == empty_val) { }; d_val[idx] = target_val_pos + offset; + + // accumulate frequency + if (d_frequency_counters != nullptr) { + atomicAdd(&d_frequency_counters[target_val_pos], input_freq); + } break; } } else if (target_key == existing_key) { while (target_val_pos == empty_val) { }; d_val[idx] = target_val_pos + offset; + + // accumulate frequency + if (d_frequency_counters != nullptr) { + atomicAdd(&d_frequency_counters[target_val_pos], input_freq); + } break; } counter++; @@ -213,7 +235,8 @@ template ::unique( const KeyType *d_key, const uint64_t len, CounterType *d_output_index, KeyType *d_unique_key, CounterType *d_output_counter, cudaStream_t stream, - CounterType *offset_ptr) { + CounterType *offset_ptr, CounterType *d_frequency_counters, + const CounterType *d_input_frequencies) { if (len == 0) { // Set the d_output_counter to 0 @@ -225,7 +248,8 @@ void unique_op::unique( get_insert_kernel <<<(len - 1) / BLOCK_SIZE_ + 1, BLOCK_SIZE_, 0, stream>>>( d_key, d_unique_key, d_output_index, len, keys_, vals_, capacity_, - counter_, empty_key, empty_val, offset_ptr); + counter_, empty_key, empty_val, d_frequency_counters, d_input_frequencies, offset_ptr); + // replace counter_ with input d_output_counter cudaMemcpyAsync(d_output_counter, counter_, sizeof(CounterType), cudaMemcpyDeviceToDevice, stream); diff --git a/corelib/dynamicemb/src/unique_variable.h b/corelib/dynamicemb/src/unique_variable.h index 3c07dbdab..683d1dd0f 100644 --- a/corelib/dynamicemb/src/unique_variable.h +++ b/corelib/dynamicemb/src/unique_variable.h @@ -162,7 +162,9 @@ class unique_op { void unique(const KeyType *d_key, const uint64_t len, CounterType *d_output_index, KeyType *d_unique_key, CounterType *d_output_counter, cudaStream_t stream, - CounterType *offset_ptr = nullptr); + CounterType *offset_ptr = nullptr, + CounterType *d_frequency_counters = nullptr, + const CounterType *d_input_frequencies = nullptr); void reset_capacity(KeyType *keys, CounterType *vals, const size_t capacity, cudaStream_t stream); diff --git a/corelib/dynamicemb/src/utils.h b/corelib/dynamicemb/src/utils.h index da3e8b399..51a31b007 100644 --- a/corelib/dynamicemb/src/utils.h +++ b/corelib/dynamicemb/src/utils.h @@ -109,6 +109,15 @@ enum class EvictStrategy : uint32_t { exit(EXIT_FAILURE); \ } +#define DISPATCH_BOOLEAN(flag, HINT, ...) \ + if (flag) { \ + constexpr bool HINT = true; \ + __VA_ARGS__(); \ + } else { \ + constexpr bool HINT = false; \ + __VA_ARGS__(); \ + } + #define HOST_INLINE __host__ __forceinline__ #define DEVICE_INLINE __device__ __forceinline__ #define HOST_DEVICE_INLINE __host__ __device__ __forceinline__ diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables.py index 07ba45e27..a0f87462d 100644 --- a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables.py +++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables.py @@ -13,16 +13,33 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from dynamicemb import ( DynamicEmbEvictStrategy, DynamicEmbPoolingMode, - DynamicEmbStorageConfig, + DynamicEmbTableOptions, EmbOptimType, ) from dynamicemb.batched_dynamicemb_tables import BatchedDynamicEmbeddingTables +@pytest.mark.parametrize( + "opt_type,opt_params", + [ + (EmbOptimType.SGD, {"learning_rate": 0.3}), + ( + EmbOptimType.ADAM, + { + "learning_rate": 0.3, + "weight_decay": 0.06, + "eps": 3e-5, + "beta1": 0.8, + "beta2": 0.888, + }, + ), + ], +) def test_embedding_optimizer(opt_type, opt_params): print( f"step in test_embedding_optimizer , opt_type = {opt_type} opt_params = {opt_params}" @@ -32,6 +49,7 @@ def test_embedding_optimizer(opt_type, opt_params): device = torch.device(f"cuda:{device_id}") dims = [128, 31, 16] + table_names = ["table0", "table1", "table2"] key_type = torch.int64 value_type = torch.float32 @@ -40,17 +58,20 @@ def test_embedding_optimizer(opt_type, opt_params): dyn_emb_table_options_list = [] for dim in dims: - dyn_emb_table_options = DynamicEmbStorageConfig( - dim=dim, init_capacity=init_capacity, max_capacity=max_capacity + dyn_emb_table_options = DynamicEmbTableOptions( + dim=dim, + init_capacity=init_capacity, + max_capacity=max_capacity, + index_type=key_type, + embedding_dtype=value_type, + device_id=device_id, + evict_strategy=DynamicEmbEvictStrategy.LRU, ) dyn_emb_table_options_list.append(dyn_emb_table_options) bdeb = BatchedDynamicEmbeddingTables( + table_names=table_names, table_options=dyn_emb_table_options_list, - index_type=key_type, - embedding_dtype=value_type, - device_id=device_id, - evict_strategy=DynamicEmbEvictStrategy.LRU, feature_table_map=[0, 0, 1, 2], pooling_mode=DynamicEmbPoolingMode.MEAN, optimizer=opt_type, @@ -80,20 +101,98 @@ def test_embedding_optimizer(opt_type, opt_params): loss.backward() -if __name__ == "__main__": - optimizer_params = [ - { - "learning_rate": 0.3, - }, - { - "learning_rate": 0.3, - "weight_decay": 0.06, - "eps": 3e-5, - "beta1": 0.8, - "beta2": 0.888, - }, - ] - - opt_types = [EmbOptimType.SGD, EmbOptimType.ADAM] - for i in range(len(opt_types)): - test_embedding_optimizer(opt_types[i], optimizer_params[i]) +@pytest.mark.parametrize( + "opt_type,opt_params", + [ + (EmbOptimType.SGD, {"learning_rate": 0.3}), + ( + EmbOptimType.ADAM, + { + "learning_rate": 0.3, + "weight_decay": 0.06, + "eps": 3e-5, + "beta1": 0.8, + "beta2": 0.888, + }, + ), + ], +) +def test_train_eval(opt_type, opt_params): + print(f"step in test_train_eval , opt_type = {opt_type} opt_params = {opt_params}") + assert torch.cuda.is_available() + device_id = 0 + device = torch.device(f"cuda:{device_id}") + + dims = [8, 8, 8] + table_names = ["table0", "table1", "table2"] + key_type = torch.int64 + value_type = torch.float32 + + init_capacity = 1024 + max_capacity = 2048 + + dyn_emb_table_options_list = [] + for dim in dims: + dyn_emb_table_options = DynamicEmbTableOptions( + dim=dim, + init_capacity=init_capacity, + max_capacity=max_capacity, + index_type=key_type, + embedding_dtype=value_type, + device_id=device_id, + evict_strategy=DynamicEmbEvictStrategy.LRU, + ) + dyn_emb_table_options_list.append(dyn_emb_table_options) + + bdebt = BatchedDynamicEmbeddingTables( + table_names=table_names, + table_options=dyn_emb_table_options_list, + feature_table_map=[0, 0, 1, 2], + pooling_mode=DynamicEmbPoolingMode.NONE, + optimizer=opt_type, + use_index_dedup=True, + **opt_params, + ) + """ + feature number = 4, batch size = 2 + + f0 [0,1], [12], + f1 [64,8], [12], + f2 [15, 2], [7,105], + f3 [], [0] + """ + indices = torch.tensor( + [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], dtype=key_type, device=device + ) + offsets = torch.tensor( + [0, 2, 3, 5, 6, 8, 10, 10, 11], dtype=key_type, device=device + ) + + embs_train = bdebt(indices, offsets) + torch.cuda.synchronize() + + with torch.no_grad(): + bdebt.eval() + embs_eval = bdebt(indices, offsets) + torch.cuda.synchronize() + + # non-exist key + indices = torch.tensor([777, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device).to( + key_type + ) + offsets = torch.tensor([0, 2, 3, 5, 6, 8, 10, 10, 11], device=device).to(key_type) + embs_non_exist = bdebt(indices, offsets) + torch.cuda.synchronize() + + # train + bdebt.train() + embs_train_non_exist = bdebt(indices, offsets) + torch.cuda.synchronize() + + assert torch.equal(embs_train, embs_eval) + assert torch.equal(embs_train[1:, :], embs_non_exist[1:, :]) + assert torch.all(embs_non_exist[0, :] == 0) + assert torch.all(embs_train_non_exist[0, :] != 0) + assert torch.equal(embs_train_non_exist[1:, :], embs_non_exist[1:, :]) + + print("all check passed") diff --git a/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py new file mode 100644 index 000000000..31423ddb8 --- /dev/null +++ b/corelib/dynamicemb/test/test_batched_dynamic_embedding_tables_v2.py @@ -0,0 +1,684 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple, cast + +import pytest +import torch +from dynamicemb import ( + DynamicEmbPoolingMode, + DynamicEmbScoreStrategy, + DynamicEmbTableOptions, + EmbOptimType, +) +from dynamicemb.batched_dynamicemb_tables import BatchedDynamicEmbeddingTablesV2 +from dynamicemb.dynamicemb_config import DynamicEmbTable +from dynamicemb.key_value_table import KeyValueTable, Storage, insert_or_assign +from dynamicemb.optimizer import BaseDynamicEmbeddingOptimizerV2 +from dynamicemb_extensions import EvictStrategy +from fbgemm_gpu.split_embedding_configs import EmbOptimType as OptimType +from fbgemm_gpu.split_embedding_configs import SparseType +from fbgemm_gpu.split_table_batched_embeddings_ops_common import ( + BoundsCheckMode, + EmbeddingLocation, + PoolingMode, +) +from fbgemm_gpu.split_table_batched_embeddings_ops_training import ( + ComputeDevice, + SplitTableBatchedEmbeddingBagsCodegen, +) + +POOLING_MODE: Dict[DynamicEmbPoolingMode, PoolingMode] = { + DynamicEmbPoolingMode.NONE: PoolingMode.NONE, + DynamicEmbPoolingMode.MEAN: PoolingMode.MEAN, + DynamicEmbPoolingMode.SUM: PoolingMode.SUM, +} +OPTIM_TYPE: Dict[EmbOptimType, OptimType] = { + EmbOptimType.SGD: OptimType.EXACT_SGD, + EmbOptimType.ADAM: OptimType.ADAM, +} + + +class PyDictStorage(Storage): + def __init__( + self, + options: DynamicEmbTableOptions, + optimizer: BaseDynamicEmbeddingOptimizerV2, + ): + self.options = options + self.dict: Dict[int, torch.Tensor] = {} + self.capacity = options.max_capacity + self.optimizer = optimizer + + self._emb_dim = self.options.dim + self._emb_dtype = self.options.embedding_dtype + self._value_dim = self._emb_dim + optimizer.get_state_dim(self._emb_dim) + self._initial_optim_state = optimizer.get_initial_optim_states() + + device_idx = torch.cuda.current_device() + self.device = torch.device(f"cuda:{device_idx}") + + def find( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + h_unique_keys = unique_keys.cpu() + lookup_dim = unique_embs.size(1) + results = [] + missing_keys = [] + missing_indices = [] + missing_scores_list = [] + founds_ = [] + for i in range(h_unique_keys.size(0)): + key = h_unique_keys[i].item() + if key in self.dict: + results.append(self.dict[key][0:lookup_dim]) + founds_.append(True) + else: + missing_keys.append(key) + missing_indices.append(i) + # Collect scores for missing keys + if input_scores is not None: + missing_scores_list.append(input_scores[i].item()) + founds_.append(False) + founds_ = torch.tensor(founds_, dtype=torch.bool, device=self.device) + if len(results) > 0: + unique_embs[founds_, :] = torch.cat( + [t.unsqueeze(0) for t in results], dim=0 + ) + if founds is not None: + founds[:] = founds_ + + num_missing = torch.tensor( + [len(missing_keys)], dtype=torch.long, device=self.device + ) + missing_keys = torch.tensor( + missing_keys, dtype=unique_keys.dtype, device=self.device + ) + missing_indices = torch.tensor( + missing_indices, dtype=torch.long, device=self.device + ) + + if input_scores is not None and len(missing_scores_list) > 0: + missing_scores = torch.tensor( + missing_scores_list, dtype=input_scores.dtype, device=self.device + ) + else: + missing_scores = torch.empty(0, dtype=torch.uint64, device=self.device) + + return num_missing, missing_keys, missing_indices, missing_scores + + def find_embeddings( + self, + unique_keys: torch.Tensor, + unique_embs: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + return self.find_impl(unique_keys, unique_embs, founds, input_scores) + + def find( + self, + unique_keys: torch.Tensor, + unique_vals: torch.Tensor, + founds: Optional[torch.Tensor] = None, + input_scores: Optional[torch.Tensor] = None, + ) -> Tuple[int, torch.Tensor, torch.Tensor, torch.Tensor]: + return self.find_impl(unique_keys, unique_vals, founds, input_scores) + + def insert( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: Optional[torch.Tensor] = None, + ) -> None: + h_keys = keys.cpu() + for i in range(h_keys.size(0)): + key = h_keys[i].item() + self.dict[key] = values[i, :].clone() + + def update( + self, keys: torch.Tensor, grads: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + raise ValueError("Can't call update of PyDictSotrage") + num_missing: torch.Tensor + missing_keys: torch.Tensor + missing_indices: torch.Tensor + return num_missing, missing_keys, missing_indices + + def enable_update(self) -> bool: + return False + + def dump( + self, + start: int, + end: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError + num_dumped: torch.Tensor + dumped_keys: torch.Tensor + dumped_values: torch.Tensor + dumped_scores: torch.Tensor + return num_dumped, dumped_keys, dumped_values, dumped_scores + + def load( + self, + keys: torch.Tensor, + values: torch.Tensor, + scores: torch.Tensor, + ) -> None: + raise NotImplementedError + + def embedding_dtype( + self, + ) -> torch.dtype: + return self._emb_dtype + + def embedding_dim( + self, + ) -> int: + return self._emb_dim + + def value_dim( + self, + ) -> int: + return self._value_dim + + def init_optimizer_state( + self, + ) -> float: + return self._initial_optim_state + + +def create_split_table_batched_embedding( + table_names, + feature_table_map, + optimizer_type, + opt_params, + dims, + num_embs, + pooling_mode, + device, +): + emb = SplitTableBatchedEmbeddingBagsCodegen( + [ + ( + e, + d, + EmbeddingLocation.DEVICE, + ComputeDevice.CUDA, + ) + for (e, d) in zip(num_embs, dims) + ], + optimizer=optimizer_type, + weights_precision=SparseType.FP32, + stochastic_rounding=False, + pooling_mode=pooling_mode, + output_dtype=SparseType.FP32, + device=device, + table_names=table_names, + feature_table_map=feature_table_map, + **opt_params, + bounds_check_mode=BoundsCheckMode.FATAL, + ).cuda() + return emb + + +def init_embedding_tables(stbe, bdet): + stbe.init_embedding_weights_uniform(0, 1) + for split, table in zip(stbe.split_embedding_weights(), bdet.tables): + num_emb = split.size(0) + emb_dim = split.size(1) + indices = torch.arange(num_emb, device=split.device, dtype=torch.long) + if isinstance(table, DynamicEmbTable): + val_dim = table.optstate_dim() + emb_dim + values = torch.empty( + num_emb, val_dim, dtype=split.dtype, device=split.device + ) + values[:, :emb_dim] = split + values[:, emb_dim:val_dim] = table.get_initial_optstate() + if table.evict_strategy() != EvictStrategy.KLru: + scores = torch.empty(num_emb, device=indices.device, dtype=torch.uint64) + scores.fill_(1) + else: + scores = None + insert_or_assign(table, num_emb, indices, values, scores) + elif isinstance(table, KeyValueTable): + table = cast(KeyValueTable, table) + val_dim = table.value_dim() + assert emb_dim == table.embedding_dim() + values = torch.empty( + num_emb, val_dim, dtype=split.dtype, device=split.device + ) + values[:, :emb_dim] = split + values[:, emb_dim:val_dim] = table.init_optimizer_state() + table.set_score(1) + table.insert(indices, values) + elif isinstance(table, PyDictStorage): + pydict = cast(PyDictStorage, table) + val_dim = pydict.value_dim() + assert emb_dim == pydict.embedding_dim() + values = torch.empty( + num_emb, val_dim, dtype=split.dtype, device=split.device + ) + values[:, :emb_dim] = split + values[:, emb_dim:val_dim] = pydict.init_optimizer_state() + pydict.insert(indices, values) + else: + raise ValueError("Not support table type") + # for states_per_table in stbe.split_optimizer_states(): + # for state in states_per_table: + # pass + + +@pytest.mark.parametrize( + "opt_type,opt_params", + [ + (EmbOptimType.SGD, {"learning_rate": 0.3}), + ( + EmbOptimType.ADAM, + { + "learning_rate": 0.3, + "weight_decay": 0.06, + "eps": 3e-5, + "beta1": 0.8, + "beta2": 0.888, + }, + ), + ], +) +@pytest.mark.parametrize("caching", [True, False]) +@pytest.mark.parametrize("PS", [None, PyDictStorage]) +def test_forward_train_eval(opt_type, opt_params, caching, PS): + print( + f"step in test_forward_train_eval , opt_type = {opt_type} opt_params = {opt_params}" + ) + assert torch.cuda.is_available() + device_id = 0 + device = torch.device(f"cuda:{device_id}") + + dims = [8, 8, 8] + table_names = ["table0", "table1", "table2"] + key_type = torch.int64 + value_type = torch.float32 + + init_capacity = 1024 + max_capacity = 2048 + + dyn_emb_table_options_list = [] + for dim in dims: + dyn_emb_table_options = DynamicEmbTableOptions( + dim=dim, + init_capacity=init_capacity, + max_capacity=max_capacity, + index_type=key_type, + embedding_dtype=value_type, + device_id=device_id, + score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, + caching=caching, + local_hbm_for_values=1024**3, + ) + dyn_emb_table_options_list.append(dyn_emb_table_options) + + bdebt = BatchedDynamicEmbeddingTablesV2( + table_names=table_names, + table_options=dyn_emb_table_options_list, + feature_table_map=[0, 0, 1, 2], + pooling_mode=DynamicEmbPoolingMode.NONE, + optimizer=opt_type, + use_index_dedup=True, + ext_ps=PS, + **opt_params, + ) + """ + feature number = 4, batch size = 2 + + f0 [0,1], [12], + f1 [64,8], [12], + f2 [15, 2], [7,105], + f3 [], [0] + """ + indices = torch.tensor( + [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], dtype=key_type, device=device + ) + offsets = torch.tensor( + [0, 2, 3, 5, 6, 8, 10, 10, 11], dtype=key_type, device=device + ) + + embs_train = bdebt(indices, offsets) + torch.cuda.synchronize() + + with torch.no_grad(): + bdebt.eval() + embs_eval = bdebt(indices, offsets) + torch.cuda.synchronize() + + # non-exist key + indices = torch.tensor([777, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device).to( + key_type + ) + offsets = torch.tensor([0, 2, 3, 5, 6, 8, 10, 10, 11], device=device).to(key_type) + embs_non_exist = bdebt(indices, offsets) + torch.cuda.synchronize() + + # train + bdebt.train() + embs_train_non_exist = bdebt(indices, offsets) + torch.cuda.synchronize() + + assert torch.equal(embs_train, embs_eval) + assert torch.equal(embs_train[1:, :], embs_non_exist[1:, :]) + assert torch.all(embs_non_exist[0, :] == 0) + assert torch.all(embs_train_non_exist[0, :] != 0) + assert torch.equal(embs_train_non_exist[1:, :], embs_non_exist[1:, :]) + + print("all check passed") + + +""" +For torchrec's adam optimizer, it will increment the optimizer_step in every forward, + which will affect the weights update, pay attention to it or try to use `set_optimizer_step()` + to control(not verified) it. +""" + + +@pytest.mark.parametrize( + "opt_type,opt_params", + [ + (EmbOptimType.SGD, {"learning_rate": 0.3}), + ( + EmbOptimType.ADAM, + { + "learning_rate": 0.3, + "weight_decay": 0.06, + "eps": 3e-5, + "beta1": 0.8, + "beta2": 0.888, + }, + ), + ], +) +@pytest.mark.parametrize( + "caching, pooling_mode, dims", + [ + (True, DynamicEmbPoolingMode.NONE, [8, 8, 8]), + (False, DynamicEmbPoolingMode.NONE, [16, 16, 16]), + (False, DynamicEmbPoolingMode.SUM, [128, 32, 16]), + (False, DynamicEmbPoolingMode.MEAN, [4, 8, 16]), + ], +) +@pytest.mark.parametrize("PS", [None, PyDictStorage]) +def test_backward(opt_type, opt_params, caching, pooling_mode, dims, PS): + print(f"step in test_backward , opt_type = {opt_type} opt_params = {opt_params}") + assert torch.cuda.is_available() + device_id = 0 + device = torch.device(f"cuda:{device_id}") + + table_names = ["table0", "table1", "table2"] + key_type = torch.int64 + value_type = torch.float32 + + max_capacity = 2048 + + dyn_emb_table_options_list = [] + for dim in dims: + dyn_emb_table_options = DynamicEmbTableOptions( + dim=dim, + init_capacity=max_capacity, + max_capacity=max_capacity, + index_type=key_type, + embedding_dtype=value_type, + device_id=device_id, + score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, + caching=caching, + local_hbm_for_values=1024**3, + ) + dyn_emb_table_options_list.append(dyn_emb_table_options) + + feature_table_map = [0, 0, 1, 2] + bdeb = BatchedDynamicEmbeddingTablesV2( + table_names=table_names, + table_options=dyn_emb_table_options_list, + feature_table_map=feature_table_map, + pooling_mode=pooling_mode, + optimizer=opt_type, + ext_ps=PS, + **opt_params, + ) + num_embs = [max_capacity // 2 for d in dims] + stbe = create_split_table_batched_embedding( + table_names, + feature_table_map, + OPTIM_TYPE[opt_type], + opt_params, + dims, + num_embs, + POOLING_MODE[pooling_mode], + device, + ) + init_embedding_tables(stbe, bdeb) + """ + feature number = 4, batch size = 2 + + f0 [0,1], [12], + f1 [64,8], [12], + f2 [15, 2, 7], [105], + f3 [], [0] + """ + for i in range(10): + indices = torch.tensor( + [0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device + ).to(key_type) + offsets = torch.tensor([0, 2, 3, 5, 6, 9, 10, 10, 11], device=device).to( + key_type + ) + + embs_bdeb = bdeb(indices, offsets) + embs_stbe = stbe(indices, offsets) + + torch.cuda.synchronize() + with torch.no_grad(): + torch.testing.assert_close(embs_bdeb, embs_stbe, rtol=1e-06, atol=1e-06) + + loss = embs_bdeb.mean() + loss.backward() + loss_stbe = embs_stbe.mean() + loss_stbe.backward() + + torch.cuda.synchronize() + torch.testing.assert_close(loss, loss_stbe) + + print(f"Passed iteration {i}") + + +@pytest.mark.parametrize( + "opt_type,opt_params", + [ + (EmbOptimType.SGD, {"learning_rate": 0.3}), + ( + EmbOptimType.ADAM, + { + "learning_rate": 0.3, + "weight_decay": 0.06, + "eps": 3e-5, + "beta1": 0.8, + "beta2": 0.888, + }, + ), + ], +) +@pytest.mark.parametrize("PS", [None, PyDictStorage]) +def test_prefetch_flush_in_cache(opt_type, opt_params, PS): + print( + f"step in test_prefetch_flush , opt_type = {opt_type} opt_params = {opt_params}" + ) + assert torch.cuda.is_available() + device_id = 0 + device = torch.device(f"cuda:{device_id}") + + table_names = ["table0", "table1", "table2"] + key_type = torch.int64 + value_type = torch.float32 + + max_capacity = 2048 + dims = [8, 8, 8] + + dyn_emb_table_options_list = [] + for dim in dims: + dyn_emb_table_options = DynamicEmbTableOptions( + dim=dim, + init_capacity=max_capacity, + max_capacity=max_capacity, + index_type=key_type, + embedding_dtype=value_type, + device_id=device_id, + score_strategy=DynamicEmbScoreStrategy.STEP, + caching=True, + local_hbm_for_values=1024**3, + ) + dyn_emb_table_options_list.append(dyn_emb_table_options) + + feature_table_map = [0, 0, 1, 2] + bdeb = BatchedDynamicEmbeddingTablesV2( + table_names=table_names, + table_options=dyn_emb_table_options_list, + feature_table_map=feature_table_map, + pooling_mode=DynamicEmbPoolingMode.NONE, + optimizer=opt_type, + enable_prefetch=False, + ext_ps=PS, + **opt_params, + ) + bdeb.enable_prefetch = True + bdeb.set_record_cache_metrics(True) + + num_embs = [max_capacity // 2 for d in dims] + stbe = create_split_table_batched_embedding( + table_names, + feature_table_map, + OPTIM_TYPE[opt_type], + opt_params, + dims, + num_embs, + POOLING_MODE[DynamicEmbPoolingMode.NONE], + device, + ) + init_embedding_tables(stbe, bdeb) + + forward_stream = torch.cuda.Stream() + pretch_stream = torch.cuda.Stream() + + # 1. Prepare input + # Input A + """ + feature number = 4, batch size = 2 + + f0 [0, 1], [12], + f1 [64,8], [12], + f2 [15, 2], [7,105], + f3 [], [0] + """ + indicesA = torch.tensor([0, 1, 12, 64, 8, 12, 15, 2, 7, 105, 0], device=device).to( + key_type + ) + offsetsA = torch.tensor([0, 2, 3, 5, 6, 8, 10, 10, 11], device=device).to(key_type) + + # Input B + # A intersection B is not none + """ + feature number = 4, batch size = 2 + + f0 [4, 12], [55], + f1 [2, 17], [1], + f2 [], [5, 13, 105], + f3 [0, 23], [42] + """ + indicesB = torch.tensor( + [4, 12, 55, 2, 17, 1, 5, 13, 105, 0, 23, 42], device=device + ).to(key_type) + offsetsB = torch.tensor([0, 2, 3, 5, 6, 6, 9, 11, 12], device=device).to(key_type) + + # stream capture will bring a cudaMalloc. + with torch.cuda.stream(forward_stream): + indicesB + 1 + with torch.cuda.stream(pretch_stream): + indicesB + 1 + + # 2. Test prefetch works when Cache empty + with torch.cuda.stream(pretch_stream): + bdeb.prefetch(indicesA, offsetsA, forward_stream) + assert bdeb.num_prefetch_ahead == 1 + assert list(bdeb.get_score().values()) == [1] * len(dims) + + with torch.cuda.stream(forward_stream): + embs_bdeb_A = bdeb(indicesA, offsetsA) + loss_bdet_A = embs_bdeb_A.mean() + loss_bdet_A.backward() + + embs_stbe_A = stbe(indicesA, offsetsA) + loss_stbe_A = embs_stbe_A.mean() + loss_stbe_A.backward() + + with torch.no_grad(): + torch.cuda.synchronize() + torch.testing.assert_close(embs_bdeb_A, embs_stbe_A, rtol=1e-06, atol=1e-06) + torch.testing.assert_close(loss_bdet_A, loss_stbe_A, rtol=1e-06, atol=1e-06) + + for cache in bdeb.caches: + metrics = cache.cache_metrics + # cache hit_rate = 100% as we do prefetch. + assert metrics[0].item() == metrics[1].item() + + with torch.no_grad(): + bdeb.flush() + bdeb.reset_cache_states() + # bdeb.set_score({table_name:1 for table_name in table_names}) + + # 3. Test prefetch works when Cache not empty + with torch.cuda.stream(pretch_stream): + bdeb.prefetch(indicesA, offsetsA, forward_stream) + bdeb.prefetch(indicesB, offsetsB, forward_stream) + assert bdeb.num_prefetch_ahead == 2 + assert list(bdeb.get_score().values()) == [2] * len(dims) + + with torch.cuda.stream(forward_stream): + embs_bdeb_A = bdeb(indicesA, offsetsA) + loss_bdet_A = embs_bdeb_A.mean() + loss_bdet_A.backward() + embs_bdeb_B = bdeb(indicesB, offsetsB) + loss_bdet_B = embs_bdeb_B.mean() + loss_bdet_B.backward() + + embs_stbe_A = stbe(indicesA, offsetsA) + loss_stbe_A = embs_stbe_A.mean() + loss_stbe_A.backward() + embs_stbe_B = stbe(indicesB, offsetsB) + loss_stbe_B = embs_stbe_B.mean() + loss_stbe_B.backward() + + with torch.no_grad(): + torch.cuda.synchronize() + torch.testing.assert_close(embs_bdeb_A, embs_stbe_A, rtol=1e-06, atol=1e-06) + torch.testing.assert_close(loss_bdet_A, loss_stbe_A, rtol=1e-06, atol=1e-06) + torch.testing.assert_close(embs_bdeb_B, embs_stbe_B, rtol=1e-06, atol=1e-06) + torch.testing.assert_close(loss_bdet_B, loss_stbe_B, rtol=1e-06, atol=1e-06) + + for cache in bdeb.caches: + metrics = cache.cache_metrics + # cache hit_rate = 100% as we do prefetch. + assert metrics[0].item() == metrics[1].item() diff --git a/corelib/dynamicemb/test/test_optimizer.py b/corelib/dynamicemb/test/test_optimizer.py index f3bd9fcd2..da6335963 100644 --- a/corelib/dynamicemb/test/test_optimizer.py +++ b/corelib/dynamicemb/test/test_optimizer.py @@ -52,7 +52,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._hashtables: @@ -112,7 +111,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._table_state_map.keys(): @@ -207,7 +205,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._table_state_map.keys(): @@ -295,7 +292,6 @@ def update( hashtables: List[DynamicEmbTable], indices: List[torch.Tensor], grads: List[torch.Tensor], - scores: Optional[List[int]] = None, ) -> None: for ht in hashtables: if ht not in self._table_state_map.keys(): @@ -600,7 +596,9 @@ def test_optimizer( opt_for_torch.update(hashtables_for_torch, indices, grads) for i in range(num_tables): opt_for_dynamicemb[i].update( - [hashtables_for_dynamicemb[i]], [indices[i]], [grads[i]] + [hashtables_for_dynamicemb[i]], + [indices[i]], + [grads[i]], ) found_weights_for_torch = [ diff --git a/corelib/dynamicemb/test/unit_test.sh b/corelib/dynamicemb/test/unit_test.sh index 1475c66ad..f7dcfa598 100644 --- a/corelib/dynamicemb/test/unit_test.sh +++ b/corelib/dynamicemb/test/unit_test.sh @@ -1,14 +1,33 @@ set -e -TEST_FILES=( - "test/test_optimizer.py" +FWD_BWD_TEST_FILES=( + "test/unit_tests/table_operation/test_table_operation.sh" + "test/unit_tests/test_lfu_scores.sh" + "test/test_batched_dynamic_embedding_tables_v2.py" "test/test_unique_op.py" "test/unit_tests/test_sequence_embedding.sh" "test/unit_tests/test_pooled_embedding.sh" + "test/unit_tests/test_twin_module.sh" + "test/unit_tests/test_alignment.py" +) + +LOAD_DUMP_TEST_FILES=( + "test/unit_tests/test_embedding_admission.sh" "test/unit_tests/test_embedding_dump_load.sh" "test/unit_tests/incremental_dump/test_incremental_dump.sh" - "test/unit_tests/test_twin_module.sh" ) -export DYNAMICEMB_DUMP_LOAD_DEBUG=1 + +case "$1" in + fwd_bwd) + TEST_FILES=("${FWD_BWD_TEST_FILES[@]}") + ;; + load_dump) + TEST_FILES=("${LOAD_DUMP_TEST_FILES[@]}") + ;; + *) + TEST_FILES=("${FWD_BWD_TEST_FILES[@]}" "${LOAD_DUMP_TEST_FILES[@]}") + ;; +esac + # Run each test file using the appropriate command for TEST_FILE in "${TEST_FILES[@]}"; do echo "Running tests in $TEST_FILE" diff --git a/corelib/dynamicemb/test/unit_tests/incremental_dump/test_incremental_dump.sh b/corelib/dynamicemb/test/unit_tests/incremental_dump/test_incremental_dump.sh index 3c3919966..c7db637ea 100644 --- a/corelib/dynamicemb/test/unit_tests/incremental_dump/test_incremental_dump.sh +++ b/corelib/dynamicemb/test/unit_tests/incremental_dump/test_incremental_dump.sh @@ -4,4 +4,4 @@ set -e pytest test/unit_tests/incremental_dump/test_dynamicemb_extensions.py -s pytest test/unit_tests/incremental_dump/test_batched_dynamicemb_tables.py -s torchrun --nproc_per_node=1 -m pytest test/unit_tests/incremental_dump/test_distributed_dynamicemb.py -s -torchrun --nproc_per_node=4 -m pytest test/unit_tests/incremental_dump/test_distributed_dynamicemb.py -s +torchrun --nproc_per_node=2 -m pytest test/unit_tests/incremental_dump/test_distributed_dynamicemb.py -s diff --git a/corelib/dynamicemb/test/unit_tests/table_operation/test_table_dump_load.py b/corelib/dynamicemb/test/unit_tests/table_operation/test_table_dump_load.py new file mode 100644 index 000000000..9b3305ebc --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/table_operation/test_table_dump_load.py @@ -0,0 +1,301 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import shutil + +import pytest +import torch +import torch.distributed as dist +from dynamicemb.scored_hashtable import ScoreArg, ScoreSpec, get_scored_table +from dynamicemb_extensions import InsertResult, ScorePolicy +from ordered_set import OrderedSet + +score_step = 0 + + +def get_scores(score_policy, keys): + batch = keys.numel() + device = keys.device + + global score_step + + score_step += 1 + + if score_policy == ScorePolicy.ASSIGN: + return torch.empty(batch, dtype=torch.uint64, device=device).fill_(score_step) + elif score_policy == ScorePolicy.ACCUMULATE: + return torch.ones(batch, dtype=torch.uint64, device=device) + else: + return torch.zeros(batch, dtype=torch.uint64, device=device) + + +@pytest.fixture(scope="session") +def backend_session(): + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + dist.barrier() + + yield + + dist.barrier() + dist.destroy_process_group() + + +def generate_files_for_accumulate( + batch_size: int, + # rank: int, + # world_size: int, + seed: int = 42, +): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + total_keys = OrderedSet() + # select_keys = OrderedSet() + while len(total_keys) < batch_size: + x = random.randint(0, (1 << 63) - 1) + total_keys.add(x) + + # if x % world_size == rank: + # select_keys.add(x) + + keys = torch.tensor(list(total_keys), dtype=torch.int64).cuda() + scores = torch.ones_like(keys) + return keys, scores + + +@pytest.mark.parametrize("key_type", [torch.int64]) +@pytest.mark.parametrize("bucket_capacity", [128, 1024]) +@pytest.mark.parametrize("num_buckets", [8192]) +@pytest.mark.parametrize("batch_size", [128 * 4096]) +@pytest.mark.parametrize( + "score_policy", + [ScorePolicy.ACCUMULATE], +) +def test_table_load( + key_type, + bucket_capacity, + num_buckets, + batch_size, + score_policy, + backend_session, +): + print("--------------------------------------------------------") + assert torch.cuda.is_available() + device = torch.cuda.current_device() + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + device = torch.cuda.current_device() + + table = get_scored_table( + capacity=num_buckets * bucket_capacity, + bucket_capacity=bucket_capacity, + key_type=key_type, + score_specs=[ScoreSpec(name="score1", policy=score_policy)], + ) + + score_args_lookup = [ + ScoreArg( + name="score1", + policy=ScorePolicy.CONST, + is_return=True, + ) + ] + + key_file = "debug_keys" + score_file = "debug_scores" + keys, scores = generate_files_for_accumulate(batch_size) + + if rank == 0: + fkey = open(key_file, "wb") + fscore = open(score_file, "wb") + fkey.write(keys.cpu().numpy().tobytes()) + fscore.write(scores.cpu().numpy().tobytes()) + fkey.close() + fscore.close() + + dist.barrier() + + table.load(key_file, {"score1": score_file}) + + masks = keys % world_size == rank + selected_keys = keys[masks] + + assert table.size() == selected_keys.numel() + + founds = torch.empty(selected_keys.numel(), dtype=torch.bool, device=device).fill_( + False + ) + score_args_lookup[0].value = torch.zeros( + selected_keys.numel(), dtype=torch.uint64, device=device + ) + + table.lookup(selected_keys, score_args_lookup, founds) + + assert founds.sum() == selected_keys.numel() + assert torch.equal( + score_args_lookup[0].value, torch.ones_like(selected_keys).to(torch.uint64) + ) + + print( + f"Table load passed when world size={world_size} and bucket capacity={bucket_capacity})" + ) + + +@pytest.mark.parametrize("key_type", [torch.int64, torch.uint64]) +@pytest.mark.parametrize("bucket_capacity", [128]) +@pytest.mark.parametrize("num_buckets", [2047, 8192]) +@pytest.mark.parametrize("batch_size", [128, 65536]) +@pytest.mark.parametrize( + "score_policy", + [ScorePolicy.ASSIGN, ScorePolicy.ACCUMULATE, ScorePolicy.GLOBAL_TIMER], +) +def test_table_dump_load( + key_type, + num_buckets, + bucket_capacity, + batch_size, + score_policy, + backend_session, +): + assert torch.cuda.is_available() + device = torch.cuda.current_device() + local_rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + batch_size = batch_size * world_size + + table = get_scored_table( + capacity=num_buckets * bucket_capacity, + bucket_capacity=bucket_capacity, + key_type=key_type, + score_specs=[ScoreSpec(name="score1", policy=score_policy)], + ) + + offset = 0 + max_step = 20 + step = 0 + while step < max_step: + keys = torch.randperm(batch_size, device=device, dtype=torch.int64) + offset + + masks = keys % world_size == local_rank + keys = keys[masks] + keys = keys.to(key_type) + batch_ = keys.numel() + + score_args = [ + ScoreArg( + name="score1", value=get_scores(score_policy, keys), is_return=True + ) + ] + + insert_results = torch.empty( + batch_, dtype=table.result_type, device=device + ).fill_(InsertResult.INIT.value) + indices = torch.zeros(batch_, dtype=table.index_type, device=device) + + table.insert(keys, score_args, indices, insert_results) + + # not assign or busy + assert ( + (insert_results == InsertResult.INSERT.value) + | (insert_results == InsertResult.EVICT.value) + ).all() + + offset += batch_size + step += 1 + + key_file = f"keys_rank{local_rank}" + score_file = f"score1_rank{local_rank}" + + shutil.rmtree(key_file, ignore_errors=True) + shutil.rmtree(score_file, ignore_errors=True) + + table.dump(key_file, {"score1": score_file}) + + load_table = get_scored_table( + capacity=num_buckets * bucket_capacity, + bucket_capacity=bucket_capacity, + key_type=key_type, + score_specs=[ScoreSpec(name="score1", policy=score_policy)], + ) + + load_table.load(key_file, {"score1": score_file}) + + assert table.size() == load_table.size() + + offset = 0 + max_step = 20 + step = 0 + + num_total_keys = 0 + while step < max_step: + keys = torch.arange(0, batch_size, 1, device=device, dtype=torch.int64) + offset + + masks = keys % world_size == local_rank + keys = keys[masks] + keys = keys.to(key_type) + batch_ = keys.numel() + + score_args0 = [ + ScoreArg( + name="score1", + value=get_scores(score_policy, keys), + policy=ScorePolicy.CONST, + is_return=True, + ) + ] + + score_args1 = [ + ScoreArg( + name="score1", + value=get_scores(score_policy, keys), + policy=ScorePolicy.CONST, + is_return=True, + ) + ] + + founds0 = torch.empty(batch_, dtype=torch.bool, device=device) + founds1 = torch.empty(batch_, dtype=torch.bool, device=device) + + table.lookup(keys, score_args0, founds0, None) + + load_table.lookup(keys, score_args1, founds1) + + assert torch.equal(founds0, founds1) + num_total_keys += founds0.sum() + + scores0 = score_args0[0].value.to(torch.int64)[founds0] + scores1 = score_args1[0].value.to(torch.int64)[founds1] + + if table.score_specs[0].policy == ScorePolicy.GLOBAL_TIMER: + # same machine + scores_bias = scores1 - scores0 + if scores_bias.numel() > 0: + assert (scores_bias == scores_bias[0]).all() + else: + assert torch.equal(scores0, scores1) + + offset += batch_size + step += 1 + + assert num_total_keys == load_table.size() diff --git a/corelib/dynamicemb/test/unit_tests/table_operation/test_table_operation.py b/corelib/dynamicemb/test/unit_tests/table_operation/test_table_operation.py new file mode 100644 index 000000000..02868fc31 --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/table_operation/test_table_operation.py @@ -0,0 +1,480 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +from typing import List + +import pytest +import torch +import torch.distributed as dist +import torchrec +from dynamicemb.scored_hashtable import ScoreArg, ScoreSpec, get_scored_table +from dynamicemb_extensions import InsertResult, ScorePolicy, table_partition + + +@pytest.fixture +def current_device(): + assert torch.cuda.is_available() + return torch.cuda.current_device() + + +def random_indices(batch, min_index, max_index): + result = set({}) + while len(result) < batch: + result.add(random.randint(min_index, max_index)) + return result + + +def generate_sparse_feature( + feature_names: List[str], + multi_hot_sizes: List[int], + local_batch_size: int, + unique_indices_list: List[set], + use_dynamicembs: List[bool], + num_embeddings: List[int], +): + feature_num = len(feature_names) + feature_batch = feature_num * local_batch_size + + indices = [] + lengths = [] + + for i in range(feature_batch): + f = i // local_batch_size + cur_bag_size = random.randint(0, multi_hot_sizes[f]) + cur_bag = set({}) + while len(cur_bag) < cur_bag_size: + if use_dynamicembs[f]: + cur_bag.add(random.randint(0, (1 << 63) - 1)) + else: + cur_bag.add(random.randint(0, num_embeddings[f] - 1)) + + unique_indices_list[f].update(cur_bag) + indices.extend(list(cur_bag)) + lengths.append(cur_bag_size) + + return torchrec.KeyedJaggedTensor( + keys=feature_names, + values=torch.tensor(indices, dtype=torch.int64).cuda(), + lengths=torch.tensor(lengths, dtype=torch.int64).cuda(), + ) + + +score_step = 0 + + +def get_scores(score_policy, keys): + batch = keys.numel() + device = keys.device + + global score_step + + score_step += 1 + + if score_policy == ScorePolicy.ASSIGN: + return torch.empty(batch, dtype=torch.uint64, device=device).fill_(score_step) + elif score_policy == ScorePolicy.ACCUMULATE: + return torch.ones(batch, dtype=torch.uint64, device=device) + else: + return torch.zeros(batch, dtype=torch.uint64, device=device) + + +@pytest.fixture(scope="session") +def backend_session(): + dist.init_process_group(backend="nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + int(os.environ["WORLD_SIZE"]) + torch.cuda.set_device(local_rank) + yield + # dist.barrier() + dist.destroy_process_group() + + +@pytest.mark.parametrize("key_type", [torch.int64, torch.uint64]) +@pytest.mark.parametrize("digest_type", [torch.uint8]) +@pytest.mark.parametrize("score_type", [torch.uint64]) +@pytest.mark.parametrize("bucket_capacity", [128, 1024]) +@pytest.mark.parametrize("num_buckets", [1, 13, 1024]) +def test_table_partition( + key_type, + digest_type, + score_type, + bucket_capacity, + num_buckets, +): + print("--------------------------------------------------------") + assert torch.cuda.is_available() + device = torch.cuda.current_device() + + dtypes = [key_type, digest_type, score_type] + dtypes_byte = [dtype.itemsize for dtype in dtypes] + storage = torch.empty( + sum(dtypes_byte) * bucket_capacity * num_buckets, + dtype=torch.uint8, + device=device, + ) + + keys, digests, scores = table_partition( + storage, + dtypes, + bucket_capacity, + num_buckets, + ) + + # dtype + assert keys.dtype == key_type + assert digests.dtype == digest_type + assert scores.dtype == score_type + + # size + assert keys.size() == (num_buckets, bucket_capacity) + assert digests.size() == (num_buckets, bucket_capacity) + assert scores.size() == (num_buckets, bucket_capacity) + + # stride + bucket_bytes = sum(dtypes_byte) * bucket_capacity + assert keys.stride() == (bucket_bytes // key_type.itemsize, 1) + assert digests.stride() == (bucket_bytes // digest_type.itemsize, 1) + assert scores.stride() == (bucket_bytes // score_type.itemsize, 1) + + # no overlap + ascend_keys = ( + torch.arange(0, num_buckets * bucket_capacity, dtype=torch.int64, device=device) + .view(num_buckets, bucket_capacity) + .to(key_type) + ) + zero_digests = torch.zeros( + num_buckets * bucket_capacity, dtype=digest_type, device=device + ).view(num_buckets, bucket_capacity) + descend_scores = ( + torch.arange( + num_buckets * bucket_capacity - 1, -1, -1, dtype=torch.int64, device=device + ) + .view(num_buckets, bucket_capacity) + .to(score_type) + ) + keys[:] = ascend_keys + digests[:] = zero_digests + scores[:] = descend_scores + assert torch.equal(keys, ascend_keys) + assert torch.equal(digests, zero_digests) + assert torch.equal(scores, descend_scores) + + table = get_scored_table( + capacity=num_buckets * bucket_capacity - 1, # corner case + bucket_capacity=bucket_capacity - 1, # corner case + key_type=key_type, + score_specs=[ScoreSpec(name="score1", policy=ScorePolicy.CONST)], + ) + + assert table.capacity() == num_buckets * bucket_capacity + assert table.key_type == key_type + assert len(table.score_specs) == 1 + + print( + "Table partition passed: table capacity and bucket capacity rounded as expected." + ) + print("Table partition passed: sizes, strides and dtype all matched.") + print( + "Table partition passed: there was no overlap across keys, digests and scores in memory address." + ) + + +@pytest.mark.parametrize("key_type", [torch.int64, torch.uint64]) +@pytest.mark.parametrize("bucket_capacity", [128, 1024]) +@pytest.mark.parametrize("num_buckets", [13, 512]) +@pytest.mark.parametrize("batch_size", [1, 32, 128]) +@pytest.mark.parametrize( + "score_policy", + [ScorePolicy.ASSIGN, ScorePolicy.ACCUMULATE, ScorePolicy.GLOBAL_TIMER], +) +def test_table_basic( + key_type, + num_buckets, + bucket_capacity, + batch_size, + score_policy, +): + print("--------------------------------------------------------") + assert torch.cuda.is_available() + device = torch.cuda.current_device() + + table = get_scored_table( + capacity=num_buckets * bucket_capacity, + bucket_capacity=bucket_capacity, + key_type=key_type, + score_specs=[ScoreSpec(name="score1", policy=score_policy)], + ) + + keys = torch.randperm(batch_size, device=device, dtype=torch.int64).to(key_type) + + score_args = [ + ScoreArg(name="score1", value=get_scores(score_policy, keys), is_return=True) + ] + score_copy_0 = score_args[0].value.clone() + insert_results = torch.empty(batch_size, dtype=table.result_type, device=device) + indices = torch.empty(batch_size, dtype=table.index_type, device=device) + + table.insert(keys, score_args, indices, insert_results) + + assert insert_results.eq(InsertResult.INSERT.value).all() + + score_args_reinsert = [ + ScoreArg(name="score1", value=get_scores(score_policy, keys), is_return=True) + ] + score_copy_1 = score_args_reinsert[0].value.clone() + insert_results = torch.zeros(batch_size, dtype=table.result_type, device=device) + indices_reinsert = torch.zeros(batch_size, dtype=table.index_type, device=device) + + table.insert(keys, score_args_reinsert, indices_reinsert, insert_results) + + assert insert_results.eq(InsertResult.ASSIGN.value).all() + assert torch.equal(indices, indices_reinsert) + + score_args_lookup = [ + ScoreArg( + name="score1", + value=get_scores(score_policy, keys), + policy=ScorePolicy.CONST, + is_return=True, + ) + ] + founds = torch.empty(batch_size, dtype=torch.bool, device=device).fill_(False) + indices_lookup = torch.empty( + batch_size, dtype=table.index_type, device=device + ).fill_(-1) + + table.lookup(keys, score_args_lookup, founds, indices_lookup) + + assert founds.all() + assert torch.equal(indices_lookup, indices) + + if table.score_specs[0].policy == ScorePolicy.ASSIGN: + assert torch.equal(score_args_lookup[0].value, score_args_reinsert[0].value) + elif table.score_specs[0].policy == ScorePolicy.ACCUMULATE: + assert torch.equal( + score_args_lookup[0].value.to(torch.int64), + score_copy_0.to(torch.int64) + score_copy_1.to(torch.int64), + ) + else: + assert torch.equal(score_args_lookup[0].value, score_args_reinsert[0].value) + assert ( + score_args[0].value.to(torch.int64) + < score_args_reinsert[0].value.to(torch.int64) + ).all() + + table.erase(keys) + table.lookup(keys, score_args_lookup, founds, indices_lookup) + assert not founds.any() + + max_num_reclaim = keys.numel() + accum_num_reclaim = 0 + + print( + "Basic table operation(insert, lookup, erase) passed during the filling stage." + ) + + offset = batch_size + max_step = 20 + step = 1 + while table.size() < table.capacity() and step < max_step: + keys = ( + torch.randperm(bucket_capacity, device=device, dtype=torch.int64) + offset + ) + keys = keys.to(key_type) + + score_args = [ + ScoreArg( + name="score1", value=get_scores(score_policy, keys), is_return=True + ) + ] + + insert_results = torch.empty( + bucket_capacity, dtype=table.result_type, device=device + ).fill_(InsertResult.INIT.value) + indices = torch.zeros(bucket_capacity, dtype=table.index_type, device=device) + + table.insert(keys, score_args, indices, insert_results) + + num_inserted = (insert_results == InsertResult.INSERT.value).sum() + num_reclaimed = (insert_results == InsertResult.RECLAIM.value).sum() + num_eviction = (insert_results == InsertResult.EVICT.value).sum() + num_assign = (insert_results == InsertResult.ASSIGN.value).sum() + + assert keys.numel() == num_inserted + num_reclaimed + num_eviction + assert num_assign == 0 + + accum_num_reclaim += num_reclaimed + + print( + f"Table insert passed when load factor({table.load_factor():.3f}) with : insert({num_inserted}), reclaim({num_reclaimed}), evict({num_eviction})" + ) + + offset += bucket_capacity + step += 1 + + if table.size() == table.capacity(): + assert ( + accum_num_reclaim == max_num_reclaim + ), f"Occupyied({accum_num_reclaim}/{max_num_reclaim}) reclaimed slots when table is full." + + keys = torch.randperm(batch_size, device=device, dtype=torch.int64) + offset + keys = keys.to(key_type) + + score_args = [ + ScoreArg( + name="score1", value=get_scores(score_policy, keys), is_return=False + ) + ] + + insert_results = torch.empty( + batch_size, dtype=table.result_type, device=device + ).fill_(InsertResult.INIT.value) + indices = torch.zeros(batch_size, dtype=table.index_type, device=device) + + table.insert(keys, score_args, indices, insert_results) + + # only eviction + assert (insert_results == InsertResult.EVICT.value).all() + + founds.fill_(True) + table.erase(keys) + table.lookup(keys, score_args, founds, indices_lookup) + assert not founds.any() + + indices_reinsert = torch.empty( + batch_size, dtype=table.index_type, device=device + ).fill_(-1) + table.insert(keys, score_args, indices_reinsert, insert_results) + + assert (insert_results == InsertResult.RECLAIM.value).all() + + assert torch.equal( + torch.sort(indices).values, torch.sort(indices_reinsert).values + ) + + print("Table operation(insert, erase, lookup) passed when table is full.") + + +@pytest.mark.parametrize("key_type", [torch.int64]) +@pytest.mark.parametrize("bucket_capacity", [128, 1024]) +@pytest.mark.parametrize("num_buckets", [8192]) +@pytest.mark.parametrize("batch_size", [65536, 1048576]) +@pytest.mark.parametrize( + "score_policy", + [ScorePolicy.ASSIGN, ScorePolicy.ACCUMULATE, ScorePolicy.GLOBAL_TIMER], +) +def test_table_evict( + key_type, + num_buckets, + bucket_capacity, + batch_size, + score_policy, +): + print("--------------------------------------------------------") + assert torch.cuda.is_available() + device = torch.cuda.current_device() + + table = get_scored_table( + capacity=num_buckets * bucket_capacity, + bucket_capacity=bucket_capacity, + key_type=key_type, + score_specs=[ScoreSpec(name="score1", policy=score_policy)], + ) + + score_args = [ScoreArg(name="score1", is_return=True)] + score_args_lookup = [ + ScoreArg( + name="score1", + policy=ScorePolicy.CONST, + is_return=True, + ) + ] + + offset = 0 + + while table.size() < table.capacity(): + keys = torch.randperm(batch_size, device=device, dtype=torch.int64) + offset + offset += batch_size + keys = keys.to(key_type) + + score_args[0].value = get_scores(score_policy, keys) + score_args_lookup[0].value = torch.zeros( + batch_size, dtype=torch.uint64, device=device + ) + + insert_results = torch.empty( + batch_size, dtype=table.result_type, device=device + ).fill_(InsertResult.INIT.value) + + indices = torch.zeros(batch_size, dtype=table.index_type, device=device) + + ( + num_evicted, + evicted_keys, + evicted_indices, + evicted_scores, + ) = table.insert_and_evict(keys, score_args, indices, insert_results) + evicted_scores = evicted_scores[0] + + founds = torch.empty(batch_size, dtype=torch.bool, device=device).fill_(False) + indices_lookup = torch.empty( + batch_size, dtype=table.index_type, device=device + ).fill_(-1) + + table.lookup(keys, score_args_lookup, founds, indices_lookup) + + num_existed = founds.sum() + + num_inserted = (insert_results == InsertResult.INSERT.value).sum() + num_reclaim = (insert_results == InsertResult.RECLAIM.value).sum() + num_assign = (insert_results == InsertResult.ASSIGN.value).sum() + num_inserted_by_eviction = (insert_results == InsertResult.EVICT.value).sum() + num_insert_failed = (insert_results == InsertResult.BUSY.value).sum() + + assert ( + num_reclaim == 0 + ), f"There is no erase operation, but got {num_reclaim} reclaimed slots when insert." + assert ( + num_assign == 0 + ), f"There is no duplicated keys, but got {num_assign} duplicated keys when insert." + + assert batch_size == num_inserted + num_inserted_by_eviction + num_insert_failed + assert num_existed == num_inserted + num_inserted_by_eviction + assert num_evicted == num_inserted_by_eviction + num_insert_failed + + assert torch.equal(indices, indices_lookup) + + if table.score_specs[0].policy == ScorePolicy.ASSIGN: + assert torch.equal( + score_args_lookup[0].value.to(torch.int64)[founds], + score_args[0].value.to(torch.int64)[founds], + ) + global score_step + assert ( + score_args_lookup[0].value.to(torch.int64)[founds] == score_step + ).all() + elif table.score_specs[0].policy == ScorePolicy.ACCUMULATE: + assert (score_args_lookup[0].value.to(torch.int64)[founds] == 1).all() + else: + assert torch.equal( + score_args_lookup[0].value.to(torch.int64)[founds], + score_args[0].value.to(torch.int64)[founds], + ) + + print( + f"Table insert_and_evict passed when load factor:({table.load_factor():.3f}) with: insert({num_inserted}), evict({num_inserted_by_eviction}), failed({num_insert_failed})" + ) diff --git a/corelib/dynamicemb/test/unit_tests/table_operation/test_table_operation.sh b/corelib/dynamicemb/test/unit_tests/table_operation/test_table_operation.sh new file mode 100644 index 000000000..b60f0f9ab --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/table_operation/test_table_operation.sh @@ -0,0 +1,7 @@ +#!/bin/bash +set -e + +pytest test/unit_tests/table_operation/test_table_operation.py -s + +torchrun --nproc_per_node=1 -m pytest test/unit_tests/table_operation/test_table_dump_load.py -s +torchrun --nproc_per_node=2 -m pytest test/unit_tests/table_operation/test_table_dump_load.py -s \ No newline at end of file diff --git a/corelib/dynamicemb/test/unit_tests/test_alignment.py b/corelib/dynamicemb/test/unit_tests/test_alignment.py new file mode 100644 index 000000000..dd24234b0 --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/test_alignment.py @@ -0,0 +1,620 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Unit test: Memory stats for varying EmbeddingConfig.num_embeddings and global_hbm_for_values, +# with caching on/off. +# +# Also compares with actual DMP model config: EmbeddingConfig -> EmbeddingCollection -> apply_dmp, +# then read max_capacity / local_hbm_for_values from _dynamicemb_options and compare to theoretical. + +import math +import os +import sys +import warnings +from typing import Any, Dict, List, Tuple + +import pytest +import torch +import torch.distributed as dist +import torch.nn as nn +from dynamicemb.dump_load import find_sharded_modules, get_dynamic_emb_module + +# Run from dynamicemb package root or with PYTHONPATH including corelib/dynamicemb +from dynamicemb.dynamicemb_config import ( + DynamicEmbInitializerArgs, + DynamicEmbInitializerMode, + DynamicEmbTableOptions, + align_to_table_size, + data_type_to_dtype, + dtype_to_bytes, + get_constraint_capacity, + get_optimizer_state_dim, +) +from dynamicemb.get_planner import get_planner +from dynamicemb.shard import DynamicEmbeddingCollectionSharder +from dynamicemb.types import DEMB_TABLE_ALIGN_SIZE +from dynamicemb_extensions import OptimizerType +from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType +from torchrec import DataType +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + +# Default world sizes for report (aligned with planner: per-rank capacity and per-rank HBM) +DEFAULT_WORLD_SIZES = [1, 8] + +# Fixed table params (aligned with planner / batched_dynamicemb_tables) +EMBEDDING_DIM = 128 +EMBEDDING_DTYPE = torch.float32 +OPTIMIZER_TYPE = OptimizerType.Adam +# Non-cache table: bucket_capacity fixed at 128; per-rank aligned capacity at least 128 +BUCKET_CAPACITY_NORMAL = 128 +# Cache mode: cache bucket_capacity=1024, minimum capacity 1024 (round up to 1 bucket if smaller) +BUCKET_CAPACITY_CACHE = 1024 + + +def _element_size() -> int: + return dtype_to_bytes(EMBEDDING_DTYPE) + + +def _optim_state_dim() -> int: + return get_optimizer_state_dim(OPTIMIZER_TYPE, EMBEDDING_DIM, EMBEDDING_DTYPE) + + +def _total_dim() -> int: + return EMBEDDING_DIM + _optim_state_dim() + + +def _byte_per_vector() -> int: + return _element_size() * _total_dim() + + +def compute_memory_stats( + num_embeddings: int, + global_hbm_for_values: int, + caching: bool, + world_size: int = 1, +) -> Tuple[int, int, int, int, int]: + """ + Compute per-rank HBM/DRAM stats for a single table (aligned with planner + batched_dynamicemb_tables). + + Planner logic: + - num_embeddings_per_rank = align_to_table_size(ceil(num_embeddings / world_size), alignment=bucket_capacity) + - If num_aligned_embedding_per_rank < bucket_capacity(128), use bucket_capacity + - local_hbm_for_values = ceil(global_hbm_for_values / world_size) + - When caching: cache bucket=1024, min capacity 1024 (get_constraint_capacity rounds up to 1 bucket) + + Returns + ------- + total_bytes_per_rank, hbm_bytes_per_rank, dram_bytes_per_rank, aligned_capacity_per_rank, total_bytes_all_ranks + """ + num_per_rank = math.ceil(num_embeddings / world_size) + aligned_capacity_per_rank = align_to_table_size( + num_per_rank, alignment=BUCKET_CAPACITY_NORMAL + ) + aligned_capacity_per_rank = max(aligned_capacity_per_rank, BUCKET_CAPACITY_NORMAL) + total_memory_per_rank = aligned_capacity_per_rank * _byte_per_vector() + + local_hbm = math.ceil(global_hbm_for_values / world_size) + local_hbm = min(local_hbm, total_memory_per_rank) + + if caching: + bucket_cap = BUCKET_CAPACITY_CACHE # 1024 + cache_capacity = get_constraint_capacity( + local_hbm, + EMBEDDING_DTYPE, + EMBEDDING_DIM, + OPTIMIZER_TYPE, + bucket_cap, + ) + # Cache min capacity 1024 (get_constraint_capacity already rounds up to 1 bucket if needed) + cache_capacity = max(cache_capacity, BUCKET_CAPACITY_CACHE) + hbm_bytes_per_rank = cache_capacity * _byte_per_vector() + # Storage holds full table shard for this rank, all in DRAM + dram_bytes_per_rank = total_memory_per_rank + else: + hbm_bytes_per_rank = local_hbm + dram_bytes_per_rank = total_memory_per_rank - local_hbm + + total_bytes_all_ranks = total_memory_per_rank * world_size + return ( + total_memory_per_rank, + hbm_bytes_per_rank, + dram_bytes_per_rank, + aligned_capacity_per_rank, + total_bytes_all_ranks, + ) + + +def _mb(x: int) -> float: + return x / (1024 * 1024) + + +def _format_mb(x: int) -> str: + return f"{_mb(x):.2f} MB" + + +def run_alignment_memory_report( + num_embeddings_list: List[int], + global_hbm_modes: List[str], + world_sizes: List[int], + include_caching: bool = True, +) -> List[dict]: + """ + Build memory report for (num_embeddings, global_hbm_for_values, caching, world_size) combinations. + global_hbm_modes: ["0", "half", "full"] = HBM budget 0 / half of total need / full (all global). + total/hbm/dram are per-rank values. + """ + rows = [] + for num_emb in num_embeddings_list: + for world_size in world_sizes: + num_per_rank = math.ceil(num_emb / world_size) + aligned_per_rank = align_to_table_size( + num_per_rank, alignment=BUCKET_CAPACITY_NORMAL + ) + aligned_per_rank = max(aligned_per_rank, BUCKET_CAPACITY_NORMAL) + total_mem_per_rank = aligned_per_rank * _byte_per_vector() + # Global HBM budget (for half/full): based on total table memory across all ranks + total_mem_global = total_mem_per_rank * world_size + + for gmode in global_hbm_modes: + if gmode == "0": + global_hbm = 0 + elif gmode == "half": + global_hbm = total_mem_global // 2 + elif gmode == "full": + global_hbm = total_mem_global + else: + raise ValueError(f"Unknown global_hbm mode: {gmode}") + + for caching in [False, True] if include_caching else [False]: + ( + total_bytes, + hbm_bytes, + dram_bytes, + aligned_cap, + total_all_ranks, + ) = compute_memory_stats(num_emb, global_hbm, caching, world_size) + rows.append( + { + "num_embeddings": num_emb, + "world_size": world_size, + "aligned_capacity_per_rank": aligned_cap, + "global_hbm_mode": gmode, + "global_hbm_bytes": global_hbm, + "caching": caching, + "total_bytes": total_bytes, + "hbm_bytes": hbm_bytes, + "dram_bytes": dram_bytes, + "total_bytes_all_ranks": total_all_ranks, + } + ) + return rows + + +def print_report(rows: List[dict], show_all_ranks: bool = False) -> None: + """Print memory consumption table. total/HBM/DRAM are per rank; optionally show all_ranks column.""" + sep = " | " + headers = [ + "num_emb", + "W", + "aligned/r", + "global_hbm", + "caching", + "total(MB)/r", + "HBM(MB)/r", + "DRAM(MB)/r", + ] + if show_all_ranks: + headers.append("total(MB)*W") + col_widths = [10, 4, 10, 8, 8, 12, 12, 12] + if show_all_ranks: + col_widths.append(12) + line = sep.join(h.ljust(col_widths[i]) for i, h in enumerate(headers)) + print(line) + print("-" * len(line)) + + for r in rows: + total_mb = _format_mb(r["total_bytes"]) + hbm_mb = _format_mb(r["hbm_bytes"]) + dram_mb = _format_mb(r["dram_bytes"]) + global_hbm_str = r["global_hbm_mode"] + row = [ + str(r["num_embeddings"]).ljust(col_widths[0]), + str(r["world_size"]).ljust(col_widths[1]), + str(r["aligned_capacity_per_rank"]).ljust(col_widths[2]), + global_hbm_str.ljust(col_widths[3]), + str(r["caching"]).ljust(col_widths[4]), + total_mb.ljust(col_widths[5]), + hbm_mb.ljust(col_widths[6]), + dram_mb.ljust(col_widths[7]), + ] + if show_all_ranks: + row.append(_format_mb(r["total_bytes_all_ranks"]).ljust(col_widths[8])) + print(sep.join(row)) + print() + + +# --------------- Compare with actual DMP model config --------------- + + +class _SingleTableTestModel(nn.Module): + """Single EmbeddingCollection, single table; used to read actual config after apply_dmp.""" + + def __init__(self, embedding_module: EmbeddingCollection): + super().__init__() + self.embedding_modules = nn.ModuleList([embedding_module]) + + def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor: + embeddings_dict = [emb(kjt).wait() for emb in self.embedding_modules] + out = [] + for d in embeddings_dict: + for v in d.values(): + out.append(v.values()) + return torch.cat(out, dim=0) + + +def _apply_dmp_with_global_hbm( + num_embeddings: int, + embedding_dim: int, + global_hbm_for_values: int, + caching: bool, + device: torch.device, + optimizer_kwargs: Dict[str, Any], +) -> nn.Module: + """ + Create single-table EmbeddingConfig -> EmbeddingCollection -> apply_dmp; return DMP model. + global_hbm_for_values is global HBM budget in bytes; planner splits by world_size per rank. + """ + from dynamicemb import DynamicEmbScoreStrategy + + name = "emb_0" + eb_config = EmbeddingConfig( + name=name, + embedding_dim=embedding_dim, + num_embeddings=num_embeddings, + feature_names=["f0"], + data_type=DataType.FP32, + ) + ebc = EmbeddingCollection( + device=torch.device("meta"), + tables=[eb_config], + ) + model = _SingleTableTestModel(ebc) + + bucket_capacity = BUCKET_CAPACITY_CACHE if caching else BUCKET_CAPACITY_NORMAL + emb_num_aligned = align_to_table_size(num_embeddings, alignment=bucket_capacity) + torch_dtype = data_type_to_dtype(DataType.FP32) + opt_state_dim = get_optimizer_state_dim( + OptimizerType.Adam, embedding_dim, torch_dtype + ) + total_hbm_need = ( + (embedding_dim + opt_state_dim) * dtype_to_bytes(torch_dtype) * emb_num_aligned + ) + # If global_hbm not set, use full need + if global_hbm_for_values <= 0: + global_hbm_for_values = total_hbm_need + + dynamicemb_options_dict = { + name: DynamicEmbTableOptions( + global_hbm_for_values=global_hbm_for_values, + score_strategy=DynamicEmbScoreStrategy.TIMESTAMP, + initializer_args=DynamicEmbInitializerArgs( + mode=DynamicEmbInitializerMode.CONSTANT, + value=0.1, + ), + bucket_capacity=bucket_capacity, + max_capacity=emb_num_aligned, + caching=caching, + ) + } + eb_configs = list(ebc.embedding_configs()) + planner = get_planner( + eb_configs, + set(), + dynamicemb_options_dict, + device, + ) + fused_params = { + "output_dtype": SparseType.FP32, + **optimizer_kwargs, + } + sharder = DynamicEmbeddingCollectionSharder( + fused_params=fused_params, + use_index_dedup=False, + ) + plan = planner.collective_plan(model, [sharder], dist.GroupMember.WORLD) + dmp = DistributedModelParallel( + module=model, + device=device, + sharders=[sharder], + plan=plan, + ) + return dmp + + +def get_actual_table_options_from_model(model: nn.Module) -> List[Dict[str, Any]]: + """ + Collect actual config (max_capacity, local_hbm_for_values, caching) for all + BatchedDynamicEmbeddingTablesV2 in the model after apply_dmp. + Uses find_sharded_modules to get ShardedEmbeddingCollection, then get_dynamic_emb_module on it. + """ + result = [] + for _path, _name, sharded_module in find_sharded_modules(model): + emb_modules = get_dynamic_emb_module(sharded_module) + for mod in emb_modules: + for opt in mod._dynamicemb_options: + result.append( + { + "max_capacity": opt.max_capacity, + "local_hbm_for_values": opt.local_hbm_for_values, + "caching": opt.caching, + } + ) + return result + + +def _compare_actual_vs_theoretical( + num_embeddings: int, + global_hbm_for_values: int, + caching: bool, + world_size: int, +) -> Tuple[bool, str]: + """ + Build DMP model, read actual config, compare with compute_memory_stats theoretical values. + When bucket floor applies, planner may set local_hbm_for_values above ceil(global_hbm/W); + we compare effective HBM = min(actual local_hbm, total_memory) to theory. + Returns (match_ok, message). + """ + ( + _, + hbm_per_rank, + _, + aligned_cap_expected, + _, + ) = compute_memory_stats(num_embeddings, global_hbm_for_values, caching, world_size) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + dmp = _apply_dmp_with_global_hbm( + num_embeddings=num_embeddings, + embedding_dim=EMBEDDING_DIM, + global_hbm_for_values=global_hbm_for_values, + caching=caching, + device=device, + optimizer_kwargs={"optimizer": EmbOptimType.ADAM, "lr": 1e-3}, + ) + actual_list = get_actual_table_options_from_model(dmp) + if not actual_list: + return False, "get_dynamic_emb_module returned no table options" + actual = actual_list[0] + cap_ok = actual["max_capacity"] == aligned_cap_expected + # Non-cache: effective HBM is min(actual, total); planner may set actual > theory when bucket floor adds capacity. + # Cache: actual HBM is cache size, may be >= theory due to bucket rounding. + total_per_rank = aligned_cap_expected * _byte_per_vector() + if not caching: + effective_actual_hbm = min(actual["local_hbm_for_values"], total_per_rank) + hbm_ok = effective_actual_hbm == hbm_per_rank + else: + hbm_ok = actual["local_hbm_for_values"] >= 0 + caching_ok = actual["caching"] == caching + msg = ( + f"num_emb={num_embeddings} W={world_size} caching={caching} " + f"expected_cap={aligned_cap_expected} actual_cap={actual['max_capacity']} " + f"expected_hbm={hbm_per_rank} actual_hbm={actual['local_hbm_for_values']}" + ) + return cap_ok and caching_ok and hbm_ok, msg + + +class TestAlignmentMemoryStats: + """Tests for memory stats under varying num_embeddings / global_hbm_for_values / caching / world_size.""" + + @pytest.fixture + def num_embeddings_list(self) -> List[int]: + # Cover alignment boundaries: below 16, equal 16, non-multiple of 16, larger + return [10, 16, 17, 32, 100, 1000, 10000] + + @pytest.fixture + def global_hbm_modes(self) -> List[str]: + return ["0", "half", "full"] + + @pytest.fixture + def world_sizes(self) -> List[int]: + return [1, 8] + + def test_align_to_table_size_default_alignment(self): + """With default alignment, result is a multiple of DEMB_TABLE_ALIGN_SIZE.""" + for n in [0, 1, 15, 16, 17, 32, 100]: + aligned = align_to_table_size(n) + assert ( + aligned % DEMB_TABLE_ALIGN_SIZE == 0 + ), f"align_to_table_size({n}) = {aligned} not multiple of {DEMB_TABLE_ALIGN_SIZE}" + if n > 0: + assert aligned >= n, f"align_to_table_size({n}) = {aligned} < {n}" + + def test_align_to_table_size_bucket_capacity(self): + """With bucket_capacity alignment, result is a multiple of bucket_capacity.""" + for bucket_cap in [128, 1024]: + for n in [0, 1, 127, 128, 129, 1000, 3125000]: + aligned = align_to_table_size(n, alignment=bucket_cap) + assert ( + aligned % bucket_cap == 0 + ), f"align_to_table_size({n}, {bucket_cap}) = {aligned} not multiple of {bucket_cap}" + if n > 0: + assert ( + aligned >= n + ), f"align_to_table_size({n}, {bucket_cap}) = {aligned} < {n}" + + def test_memory_stats_total_consistent( + self, num_embeddings_list, global_hbm_modes, world_sizes + ): + """Per-rank total matches aligned_capacity_per_rank; non-caching: HBM+DRAM=total; caching: DRAM=total.""" + rows = run_alignment_memory_report( + num_embeddings_list, global_hbm_modes, world_sizes, include_caching=True + ) + for r in rows: + expected_total = r["aligned_capacity_per_rank"] * _byte_per_vector() + assert ( + r["total_bytes"] == expected_total + ), f"total_bytes mismatch: {r['total_bytes']} vs {expected_total}" + if not r["caching"]: + assert ( + r["hbm_bytes"] + r["dram_bytes"] == r["total_bytes"] + ), f"non-caching: hbm + dram != total: {r}" + else: + assert ( + r["dram_bytes"] == r["total_bytes"] + ), f"caching: dram should equal total (storage): {r}" + + def test_caching_increases_hbm_when_global_hbm_nonzero( + self, num_embeddings_list, global_hbm_modes, world_sizes + ): + """When global_hbm is non-zero, HBM under caching is determined by cache capacity.""" + for num_emb in num_embeddings_list: + for world_size in world_sizes: + num_per_rank = math.ceil(num_emb / world_size) + aligned = align_to_table_size( + num_per_rank, alignment=BUCKET_CAPACITY_NORMAL + ) + aligned = max(aligned, BUCKET_CAPACITY_NORMAL) + total_mem_global = aligned * _byte_per_vector() * world_size + for gmode in ["half", "full"]: + global_hbm = ( + total_mem_global // 2 if gmode == "half" else total_mem_global + ) + _, hbm_no_cache, _, _, _ = compute_memory_stats( + num_emb, global_hbm, caching=False, world_size=world_size + ) + _, hbm_cache, _, _, _ = compute_memory_stats( + num_emb, global_hbm, caching=True, world_size=world_size + ) + assert hbm_cache >= 0 and hbm_no_cache >= 0 + + def test_alignment_memory_report_runs( + self, num_embeddings_list, global_hbm_modes, world_sizes + ): + """Run full report and assert every row has valid values.""" + rows = run_alignment_memory_report( + num_embeddings_list, global_hbm_modes, world_sizes, include_caching=True + ) + assert len(rows) > 0 + for r in rows: + assert r["total_bytes"] > 0 + assert r["hbm_bytes"] >= 0 and r["dram_bytes"] >= 0 + num_per_rank = math.ceil(r["num_embeddings"] / r["world_size"]) + assert r["aligned_capacity_per_rank"] >= align_to_table_size( + num_per_rank, alignment=BUCKET_CAPACITY_NORMAL + ) + + def test_multi_rank_reduces_per_rank_memory(self): + """With world_size=8, per-rank total_bytes should be less than full-table size with world_size=1.""" + num_emb = 10000 + total_ws1 = ( + align_to_table_size(num_emb, alignment=BUCKET_CAPACITY_NORMAL) + * _byte_per_vector() + ) + total_ws8_per_rank, _, _, _, _ = compute_memory_stats( + num_emb, global_hbm_for_values=0, caching=False, world_size=8 + ) + assert total_ws8_per_rank < total_ws1 + + def test_num_aligned_embedding_per_rank_bucket_floor(self): + """num_aligned_embedding_per_rank is bounded by bucket_capacity.""" + min_cap = BUCKET_CAPACITY_NORMAL + assert min_cap == 128 + for num_emb in [1, 10, 17, 50]: + for world_size in [1, 8]: + _, _, _, aligned_cap, _ = compute_memory_stats( + num_emb, + global_hbm_for_values=0, + caching=False, + world_size=world_size, + ) + assert ( + aligned_cap >= min_cap + ), f"num_emb={num_emb} world_size={world_size} aligned_cap={aligned_cap} < {min_cap}" + + def test_cache_min_capacity_1024(self): + """When caching is on, cache min capacity is 1024 (round up to 1 bucket if smaller).""" + for num_emb in [10, 100, 1000]: + for world_size in [1, 8]: + _, hbm_bytes, _, _, _ = compute_memory_stats( + num_emb, + global_hbm_for_values=0, + caching=True, + world_size=world_size, + ) + cache_capacity_rows = hbm_bytes // _byte_per_vector() + assert ( + cache_capacity_rows >= BUCKET_CAPACITY_CACHE + ), f"num_emb={num_emb} W={world_size} cache_capacity_rows={cache_capacity_rows} < 1024" + + @pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA required for DMP model creation", + ) + def test_actual_capacity_matches_theoretical(self): + """ + Compare with actual DMP model config: create EmbeddingConfig -> EmbeddingCollection -> + DynamicEmbTableOptions -> apply_dmp to get model, read actual max_capacity / local_hbm_for_values + from model and assert they match compute_memory_stats theoretical values. + Run with torchrun to init dist and perform comparison. + """ + if not dist.is_initialized(): + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + dist.init_process_group( + backend="gloo" if not torch.cuda.is_available() else "nccl", + init_method="env://", + ) + else: + pytest.skip( + "Distributed not initialized; run with torchrun to compare actual vs theoretical." + ) + if not dist.is_initialized(): + pytest.skip("Failed to init process group") + world_size = dist.get_world_size() + num_embeddings = 1000 + # Global HBM = full (same as "full" in theoretical report; use same bucket floor as compute_memory_stats) + aligned_per_rank = align_to_table_size( + math.ceil(num_embeddings / world_size), + alignment=BUCKET_CAPACITY_NORMAL, + ) + aligned_per_rank = max(aligned_per_rank, BUCKET_CAPACITY_NORMAL) + total_mem_global = aligned_per_rank * _byte_per_vector() * world_size + global_hbm = total_mem_global + ok, msg = _compare_actual_vs_theoretical( + num_embeddings=num_embeddings, + global_hbm_for_values=global_hbm, + caching=False, + world_size=world_size, + ) + assert ok, msg + + +def main(): + """CLI entry: print memory stats for varying num_embeddings / global_hbm_for_values / caching / world_size.""" + num_embeddings_list = [10, 16, 17, 32, 100, 1000, 10000] + global_hbm_modes = ["0", "half", "full"] + world_sizes = DEFAULT_WORLD_SIZES # e.g. [1, 8] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + + print("Alignment & HBM memory report (per rank; dim=128, Adam)") + print("DEMB_TABLE_ALIGN_SIZE =", DEMB_TABLE_ALIGN_SIZE) + print( + "W = world_size; aligned/r = aligned capacity per rank; global_hbm = global budget" + ) + print("total/HBM/DRAM = per rank (total(MB)*W = all ranks total table memory)") + print() + + rows = run_alignment_memory_report( + num_embeddings_list, global_hbm_modes, world_sizes, include_caching=True + ) + print_report(rows, show_all_ranks=True) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py b/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py new file mode 100644 index 000000000..6c1cd8748 --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/test_embedding_admission.py @@ -0,0 +1,371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import os +import random +from collections import defaultdict +from typing import Dict, List, Optional, Set, Tuple + +import click +import torch +import torch.distributed as dist +import torch.nn as nn +from dynamicemb.dump_load import find_sharded_modules, get_dynamic_emb_module +from dynamicemb.embedding_admission import FrequencyAdmissionStrategy +from dynamicemb.key_value_table import batched_export_keys_values +from dynamicemb.types import DynamicEmbInitializerArgs + +# from dynamicemb.admission_strategy import FrequencyAdmissionStrategy +from test_embedding_dump_load import ( + create_model, + get_optimizer_kwargs, + get_score_strategy, + idx_to_name, +) +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def generate_deterministic_sparse_features_with_frequency_tracking( + num_embedding_collections: int, + num_embeddings: List[int], + multi_hot_sizes: List[int], + rank: int, + world_size: int, + batch_size: int, + num_iterations: int, + seed: int = 42, + caching: bool = False, +) -> Tuple[List[KeyedJaggedTensor], Dict[str, Dict[int, int]]]: + """ + Generate deterministic sparse features and track frequency for each embedding table. + + Args: + caching: If True, generate more unique keys to trigger cache eviction + + Returns: + kjts: List of KeyedJaggedTensor for each iteration + table_frequency_counters: Dict mapping table_name -> {key: frequency} + """ + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + batch_size_per_rank = batch_size // world_size + kjts = [] + table_frequency_counters = {} + + # Initialize frequency counters for each table + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, num_embedding in enumerate(num_embeddings): + _, embedding_name = idx_to_name(embedding_collection_id, embedding_id) + table_frequency_counters[embedding_name] = defaultdict(int) + + for iteration in range(num_iterations): + cur_indices = [] + cur_lengths = [] + keys = [] + + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, num_embedding in enumerate(num_embeddings): + feature_name, embedding_name = idx_to_name( + embedding_collection_id, embedding_id + ) + + for sample_id in range(batch_size): + hotness = random.randint( + 1, multi_hot_sizes[embedding_collection_id] + ) + if caching: + # In caching mode, use wider range to trigger eviction + max_key = int(num_embedding * 0.8) - 1 + else: + # In storage-only mode, limit to 100 keys for more duplicates + max_key = min(num_embedding - 1, 100) + indices = [random.randint(0, max_key) for _ in range(hotness)] + # Track frequency for all generated indices + for idx in indices: + table_frequency_counters[embedding_name][idx] += 1 + + if sample_id // batch_size_per_rank == rank: + cur_indices.extend(indices) + cur_lengths.append(hotness) + + keys.append(feature_name) + + kjts.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor(cur_indices, dtype=torch.int64).cuda(), + lengths=torch.tensor(cur_lengths, dtype=torch.int64).cuda(), + ) + ) + + # Convert defaultdicts to regular dicts + for table_name in table_frequency_counters: + table_frequency_counters[table_name] = dict( + table_frequency_counters[table_name] + ) + + return kjts, table_frequency_counters + + +def get_table_keys( + model: nn.Module, + table_names: Optional[Dict[str, List[str]]] = None, + pg: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Set[int]]: + """ + Get keys from dynamic embedding tables directly (without disk I/O). + + Returns: + Dict[table_name, Set[key]]: Keys stored in each table + """ + + if torch.cuda.is_available(): + torch.cuda.synchronize() + dist.barrier(group=pg, device_ids=[torch.cuda.current_device()]) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + batch_size = 65536 + + all_table_keys = {} + + for _, collection_name, sharded_module in find_sharded_modules(model, ""): + dynamic_emb_modules = get_dynamic_emb_module(sharded_module) + + for dynamic_emb_module in dynamic_emb_modules: + dynamic_emb_module.flush() + + for table_name, table in zip( + dynamic_emb_module.table_names, dynamic_emb_module.tables + ): + if table_names is not None and table_name not in set( + table_names[collection_name] + ): + continue + + table_keys = set() + + for keys, _, _, _ in batched_export_keys_values( + table.table, device, batch_size + ): + for key in keys: + table_keys.add(int(key)) + + all_table_keys[table_name] = table_keys + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + dist.barrier(group=pg, device_ids=[torch.cuda.current_device()]) + + return all_table_keys + + +def validate_admission_keys( + expected_frequencies: Dict[str, Dict[int, int]], + actual_keys: Dict[str, Set[int]], + threshold: int, +): + """ + Validate that only keys with frequency >= threshold are stored in tables. + + In multi-GPU scenarios, keys are sharded across GPUs. Each rank only validates + keys that are actually present in its local tables (i.e., keys that were sharded to it). + This is similar to how LFU score validation works. + + Args: + expected_frequencies: Dict mapping table_name -> {key: frequency} + actual_keys: Dict mapping table_name -> Set[key] + threshold: Admission threshold + """ + for table_name in expected_frequencies: + if table_name not in actual_keys: + raise AssertionError(f"Table {table_name} missing from actual keys") + + expected = expected_frequencies[table_name] + actual = actual_keys[table_name] + + for key in actual: + if key not in expected: + raise AssertionError( + f"Table {table_name}, Key {key}: " + f"Found in table but not in expected frequencies (unexpected key)" + ) + + if expected[key] < threshold: + raise AssertionError( + f"Table {table_name}, Key {key}: " + f"Admitted with frequency={expected[key]} < threshold={threshold} " + f"(should have been rejected)" + ) + + # Count expected admitted/rejected for reporting + expected_admitted_count = sum( + 1 for freq in expected.values() if freq >= threshold + ) + expected_rejected_count = sum( + 1 for freq in expected.values() if freq < threshold + ) + actual_admitted_count = len(actual) + + print( + f"✓ Table {table_name}: " + f"{actual_admitted_count} keys admitted on this rank " + f"(global: {expected_admitted_count} expected admitted, " + f"{expected_rejected_count} expected rejected), " + f"threshold={threshold}" + ) + + +@click.command() +@click.option("--num-embedding-collections", type=int, default=1) +@click.option("--num-embeddings", type=str, default="1000") +@click.option("--multi-hot-sizes", type=str, default="3") +@click.option("--embedding-dim", type=int, default=16) +@click.option( + "--optimizer-type", + type=click.Choice(["sgd", "adam", "adagrad", "rowwise_adagrad"]), + default="sgd", +) +@click.option("--batch-size", type=int, default=16) +@click.option("--num-iterations", type=int, default=3) +@click.option("--threshold", type=int, default=5, help="Admission frequency threshold") +@click.option("--caching", is_flag=True, help="Enable cache + storage architecture") +@click.option( + "--cache-capacity-ratio", + type=float, + default=0.5, + help="Cache capacity as ratio of storage capacity (only used when --caching is enabled)", +) +@click.option( + "--score-strategy", + type=click.Choice(["timestamp", "lfu", "step"]), + default="timestamp", + help="Score strategy", +) +def test_admission_strategy_validation( + num_embedding_collections: int, + num_embeddings: str, + multi_hot_sizes: str, + embedding_dim: int, + optimizer_type: str, + batch_size: int, + num_iterations: int, + threshold: int, + caching: bool, + cache_capacity_ratio: float, + score_strategy: str, +): + """Test admission strategy correctness by comparing with naive frequency counting. + + This test validates that only keys with frequency >= threshold are stored in tables. + It supports two modes: + - Storage-only (default): Tests admission in storage directly + - Cache + Storage (--caching): Tests admission through cache to storage + """ + + num_embeddings = [int(v) for v in num_embeddings.split(",")] + multi_hot_sizes = [int(v) for v in multi_hot_sizes.split(",")] + use_index_dedup = True + + if not caching: + for num_embedding, multi_hot_size in zip(num_embeddings, multi_hot_sizes): + if batch_size * num_iterations * multi_hot_size > num_embedding: + raise ValueError( + "batch_size * num_iterations * multi_hot_size > num_embedding, " + "this may lead to eviction of dynamicemb and cause test fail" + ) + + print(f"Configuration:") + print(f" - Embedding collections: {num_embedding_collections}") + print(f" - Num embeddings: {num_embeddings}") + print(f" - Multi-hot sizes: {multi_hot_sizes}") + print(f" - Embedding dim: {embedding_dim}") + print(f" - Optimizer: {optimizer_type}") + print(f" - Batch size: {batch_size}") + print(f" - Iterations: {num_iterations}") + print(f" - Admission threshold: {threshold}") + print(f" - Use index dedup: {use_index_dedup}") + print(f" - Score strategy: {score_strategy}") + if caching: + print(f" - Caching: ENABLED ✓") + print(f" - Cache capacity ratio: {cache_capacity_ratio}") + else: + print(f" - Caching: DISABLED") + + # Create admission strategy + admission_strategy = FrequencyAdmissionStrategy( + threshold=threshold, + initializer_args=DynamicEmbInitializerArgs( + value=0.0, + ), + ) + + # Create model with admission strategy + optimizer_kwargs = get_optimizer_kwargs(optimizer_type) + model = create_model( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + optimizer_kwargs=optimizer_kwargs, + score_strategy=get_score_strategy( + score_strategy + ), # Use timestamp for admission + use_index_dedup=use_index_dedup, + caching=caching, + cache_capacity_ratio=cache_capacity_ratio if caching else 0.1, + admit_strategy=admission_strategy, # Pass admission strategy + ) + + # Generate features with frequency tracking + ( + kjts, + expected_frequencies, + ) = generate_deterministic_sparse_features_with_frequency_tracking( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + multi_hot_sizes=multi_hot_sizes, + rank=dist.get_rank(), + world_size=dist.get_world_size(), + batch_size=batch_size, + num_iterations=num_iterations, + caching=caching, + ) + + # Run forward passes to trigger admission logic + if caching: + print( + f"\nRunning {num_iterations} iterations with cache and admission enabled..." + ) + else: + print(f"\nRunning {num_iterations} iterations with admission enabled...") + + for iteration, kjt in enumerate(kjts): + ret = model(kjt) + torch.cuda.synchronize() + loss = ret.sum() * dist.get_world_size() + loss.backward() + torch.cuda.synchronize() + + # Extract actual keys stored in tables + actual_keys = get_table_keys(model) + + torch.cuda.synchronize() + + # Validate admission logic + print(f"\nValidating admission with threshold={threshold}...") + validate_admission_keys(expected_frequencies, actual_keys, threshold) + + print(f"\n✓ Admission strategy test passed!") + + +if __name__ == "__main__": + LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(LOCAL_RANK) + + dist.init_process_group(backend="nccl") + test_admission_strategy_validation() + dist.barrier() + dist.destroy_process_group() diff --git a/corelib/dynamicemb/test/unit_tests/test_embedding_admission.sh b/corelib/dynamicemb/test/unit_tests/test_embedding_admission.sh new file mode 100644 index 000000000..6c85fc111 --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/test_embedding_admission.sh @@ -0,0 +1,138 @@ +#!/bin/bash +set -e + +# Test configurations +NUM_EMBEDDING_COLLECTIONS=2 +NUM_EMBEDDINGS=10000,10000,10000,10000 +MULTI_HOT_SIZES=5,5,5,5 +EMBEDDING_DIM=16 +NUM_GPUS=(1 4) +OPTIMIZER_TYPE=("sgd" "adam" "adagrad" "rowwise_adagrad") +BATCH_SIZE=32 +NUM_ITERATIONS=10 +THRESHOLD=4 +SCORE_STRATEGY=("timestamp" "lfu" "step") + +# Cache configurations +CACHING_MODES=("False" "True") +CACHE_CAPACITY_RATIO=0.3 # 30% cache capacity to trigger evictions + + +for num_gpus in ${NUM_GPUS[@]}; do + for optimizer_type in ${OPTIMIZER_TYPE[@]}; do + for score_strategy in ${SCORE_STRATEGY[@]}; do + echo "" + echo "----------------------------------------" + echo "Test: Storage-Only | GPUs: $num_gpus | Optimizer: $optimizer_type" + echo "----------------------------------------" + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + ./test/unit_tests/test_embedding_admission.py \ + --num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \ + --num-embeddings $NUM_EMBEDDINGS \ + --multi-hot-sizes $MULTI_HOT_SIZES \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type ${optimizer_type} \ + --batch-size $BATCH_SIZE \ + --num-iterations $NUM_ITERATIONS \ + --threshold $THRESHOLD \ + --score-strategy ${score_strategy} || exit 1 + done + done +done + +for num_gpus in ${NUM_GPUS[@]}; do + for optimizer_type in ${OPTIMIZER_TYPE[@]}; do + for score_strategy in ${SCORE_STRATEGY[@]}; do + echo "" + echo "----------------------------------------" + echo "Test: Cache+Storage | GPUs: $num_gpus | Optimizer: $optimizer_type | Cache Ratio: $CACHE_CAPACITY_RATIO" + echo "----------------------------------------" + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + ./test/unit_tests/test_embedding_admission.py \ + --num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \ + --num-embeddings $NUM_EMBEDDINGS \ + --multi-hot-sizes $MULTI_HOT_SIZES \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type ${optimizer_type} \ + --batch-size $BATCH_SIZE \ + --num-iterations $NUM_ITERATIONS \ + --threshold $THRESHOLD \ + --caching \ + --cache-capacity-ratio $CACHE_CAPACITY_RATIO \ + --score-strategy ${score_strategy} || exit 1 + done + done +done + + + +# High-frequency test: more iterations to test frequency accumulation + +HIGH_FREQ_ITERATIONS=50 +for caching_mode in "without-cache" "with-cache"; do + echo "" + echo "----------------------------------------" + echo "Test: High Frequency ($HIGH_FREQ_ITERATIONS iters) | Mode: $caching_mode" + echo "----------------------------------------" + if [ "$caching_mode" = "without-cache" ]; then + torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + ./test/unit_tests/test_embedding_admission.py \ + --num-embedding-collections 1 \ + --num-embeddings 5000 \ + --multi-hot-sizes 3 \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type sgd \ + --batch-size 16 \ + --num-iterations $HIGH_FREQ_ITERATIONS \ + --threshold $THRESHOLD || exit 1 + else + torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + ./test/unit_tests/test_embedding_admission.py \ + --num-embedding-collections 1 \ + --num-embeddings 5000 \ + --multi-hot-sizes 3 \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type sgd \ + --batch-size 16 \ + --num-iterations $HIGH_FREQ_ITERATIONS \ + --threshold $THRESHOLD \ + --caching \ + --cache-capacity-ratio 0.4 || exit 1 + fi +done + + + +EVICTION_CACHE_RATIO_1=0.08 # 2% - Very small cache +EVICTION_BATCH_SIZE_1=64 # Large batch size +EVICTION_ITERATIONS_1=25 # Many iterations + +for optimizer_type in "sgd" "adam"; do + echo "" + echo "----------------------------------------" + echo "Test: Ultra-small Cache | Optimizer: $optimizer_type | Cache Ratio: $EVICTION_CACHE_RATIO_1" + echo "----------------------------------------" + torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + ./test/unit_tests/test_embedding_admission.py \ + --num-embedding-collections 1 \ + --num-embeddings 10000 \ + --multi-hot-sizes 5 \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type ${optimizer_type} \ + --batch-size $EVICTION_BATCH_SIZE_1 \ + --num-iterations $EVICTION_ITERATIONS_1 \ + --threshold $THRESHOLD \ + --caching \ + --cache-capacity-ratio $EVICTION_CACHE_RATIO_1 || exit 1 +done + diff --git a/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py b/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py index 180144013..06383a277 100644 --- a/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py +++ b/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.py @@ -13,658 +13,829 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse +import math import os import random import shutil -import sys -from typing import List +from itertools import product +from typing import Any, Dict, List, Tuple +import click import torch import torch.distributed as dist -import torchrec -from debug import Debugger +import torch.nn as nn from dynamicemb import ( - DynamicEmbCheckMode, - DynamicEmbDump, - DynamicEmbInitializerArgs, - DynamicEmbInitializerMode, - DynamicEmbLoad, DynamicEmbScoreStrategy, DynamicEmbTableOptions, + FrequencyAdmissionStrategy, ) -from dynamicemb.incremental_dump import get_score, set_score -from dynamicemb.planner import ( - DynamicEmbeddingEnumerator, - DynamicEmbeddingShardingPlanner, - DynamicEmbParameterConstraints, -) -from dynamicemb.shard import DynamicEmbeddingCollectionSharder -from fbgemm_gpu.split_embedding_configs import EmbOptimType -from torch.distributed.elastic.multiprocessing.errors import record -from torch.distributed.optim import ( - _apply_optimizer_in_backward as apply_optimizer_in_backward, -) -from torchrec.distributed.comm import get_local_size -from torchrec.distributed.fbgemm_qcomm_codec import ( - CommType, - QCommsConfig, - get_qcomm_codecs_registry, -) -from torchrec.distributed.model_parallel import ( - DefaultDataParallelWrapper, - DistributedModelParallel, +from dynamicemb.dump_load import ( + DynamicEmbDump, + DynamicEmbLoad, + find_sharded_modules, + get_dynamic_emb_module, ) -from torchrec.distributed.planner import ParameterConstraints, Topology -from torchrec.distributed.planner.storage_reservations import ( - HeuristicalStorageReservation, +from dynamicemb.dynamicemb_config import ( + DynamicEmbInitializerArgs, + DynamicEmbInitializerMode, ) -from torchrec.distributed.types import BoundsCheckMode, ShardingType - - -def str2bool(v): - if isinstance(v, bool): - return v - if v.lower() in ("yes", "true", "t", "y", "1"): - return True - elif v.lower() in ("no", "false", "f", "n", "0"): - return False +from dynamicemb.embedding_admission import KVCounter +from dynamicemb.get_planner import get_planner +from dynamicemb.key_value_table import batched_export_keys_values +from dynamicemb.scored_hashtable import ScoreArg, ScorePolicy +from dynamicemb.shard import DynamicEmbeddingCollectionSharder +from dynamicemb.types import AdmissionStrategy +from dynamicemb.utils import TORCHREC_TYPES +from fbgemm_gpu.split_embedding_configs import EmbOptimType, SparseType +from torchrec import DataType +from torchrec.distributed.model_parallel import DistributedModelParallel +from torchrec.modules.embedding_configs import EmbeddingConfig +from torchrec.modules.embedding_modules import EmbeddingCollection +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def idx_to_name(embedding_collection_idx: int, embedding_idx: int) -> Tuple[str, str]: + return ( + f"feature_{embedding_collection_idx}_{embedding_idx}", + f"embedding_collection_{embedding_collection_idx}_{embedding_idx}", + ) + + +def get_optimizer_kwargs(optimizer_type: str) -> Dict[str, Any]: + if optimizer_type == "sgd": + return {"optimizer": EmbOptimType.SGD, "lr": 1e-3} + elif optimizer_type == "adam": + return {"optimizer": EmbOptimType.ADAM, "lr": 1e-3} + elif optimizer_type == "adagrad": + return {"optimizer": EmbOptimType.EXACT_ADAGRAD, "lr": 1e-3} + elif optimizer_type == "rowwise_adagrad": + return {"optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD, "lr": 1e-3} else: - raise argparse.ArgumentTypeError("Boolean value expected.") - - -def table_idx_to_name(i): - return f"t_{i}" - + raise ValueError("unknown optimizer type") -def feature_idx_to_name(i): - return f"cate_{i}" - -def get_comm_precission(precision_str): - if precision_str == "fp32": - return CommType.FP32 - elif precision_str == "fp16": - return CommType.FP16 - elif precision_str == "bf16": - return CommType.BF16 - elif precision_str == "fp8": - return CommType.FP8 +def get_score_strategy(score_strategy_str: str) -> DynamicEmbScoreStrategy: + if score_strategy_str == "timestamp": + return DynamicEmbScoreStrategy.TIMESTAMP + elif score_strategy_str == "step": + return DynamicEmbScoreStrategy.STEP + elif score_strategy_str == "lfu": + return DynamicEmbScoreStrategy.LFU else: - raise ValueError("unknown comm precision type") - - -class CustomizedScore: - def __init__(self, table_names: List[int]): - self.table_names_ = table_names - self.steps_: Dict[str, int] = {table_name: 1 for table_name in table_names} + raise ValueError(f"Invalid score strategy: {score_strategy_str}") - def get(self, table_name: str): - assert table_name in self.table_names_ - ret = self.steps_[table_name] - self.steps_[table_name] += 1 - return ret - -def get_planner(args, device, eb_configs): - dict_const = {} - for i in range(args.num_embedding_table): - if ( - args.data_parallel_embeddings is not None - and i in args.data_parallel_embeddings - ): - const = ParameterConstraints( - sharding_types=[ShardingType.DATA_PARALLEL.value], - # min_partition=2, - pooling_factors=[args.multi_hot_sizes[i]], - num_poolings=[1], - enforce_hbm=True, - bounds_check_mode=BoundsCheckMode.NONE, - ) +def update_scores( + score_strategy: str, + expect_scores: Dict[int, int], + key: int, + step: int, +): + if score_strategy == "step": + expect_scores[key] = step + elif score_strategy == "lfu": + if key not in expect_scores: + expect_scores[key] = 1 else: - use_dynamicemb = True if i < args.dynamicemb_num else False - const = DynamicEmbParameterConstraints( - sharding_types=[ - ShardingType.ROW_WISE.value, - # ShardingType.COLUMN_WISE.value, - # ShardingType.ROW_WISE.value, - # ShardingType.TABLE_ROW_WISE.value, - # ShardingType.TABLE_COLUMN_WISE.value, - ], - # min_partition=2, - pooling_factors=[args.multi_hot_sizes[i]], - num_poolings=[1], - enforce_hbm=True, - bounds_check_mode=BoundsCheckMode.NONE, - use_dynamicemb=use_dynamicemb, - dynamicemb_options=DynamicEmbTableOptions( - global_hbm_for_values=1024**3, - initializer_args=DynamicEmbInitializerArgs( - mode=DynamicEmbInitializerMode.DEBUG, - ), - safe_check_mode=DynamicEmbCheckMode.WARNING, - score_strategy=args.score_strategies, - ), - ) - dict_const[table_idx_to_name(i)] = const - - topology = Topology( - local_world_size=get_local_size(), - world_size=dist.get_world_size(), - compute_device=device.type, - hbm_cap=args.hbm_cap, - ddr_cap=1024 * 1024 * 1024 * 1024, - # simulate DynamicEmb table is big and have other table - # hbm_cap=int(340000000/dist.get_world_size()), - # ddr_cap=1, - intra_host_bw=args.intra_host_bw, - inter_host_bw=args.inter_host_bw, - ) - - enumerator = DynamicEmbeddingEnumerator( - topology=topology, - # batch_size=args.batch_size, - constraints=dict_const, - ) - - return DynamicEmbeddingShardingPlanner( - eb_configs=eb_configs, - topology=topology, - constraints=dict_const, - batch_size=args.batch_size, - enumerator=enumerator, - # # If experience OOM, increase the percentage. see - # # https://pytorch.org/torchrec/torchrec.distributed.planner.html#torchrec.distributed.planner.storage_reservations.HeuristicalStorageReservation - storage_reservation=HeuristicalStorageReservation(percentage=0.05), - debug=True, - ) + expect_scores[key] = expect_scores[key] + 1 + else: + return -def init_fn(x: torch.Tensor): - with torch.no_grad(): - x.fill_(2.0) +def update_scores( + score_strategy: str, + expect_scores: Dict[int, int], + key: int, + step: int, +): + if score_strategy == "step": + expect_scores[key] = step + elif score_strategy == "lfu": + if key not in expect_scores: + expect_scores[key] = 1 + else: + expect_scores[key] = expect_scores[key] + 1 + else: + return def generate_sparse_feature( - feature_num, num_embeddings_list, multi_hot_sizes, local_batch_size=50 + num_embedding_collections: int, + num_embeddings: List[int], + multi_hot_sizes: List[int], + rank: int, + world_size: int, + batch_size: int, + num_iterations: int, + score_strategy: str, + scores_collection: Dict[str, Dict[int, int]], + seed: int = 42, ): feature_batch = feature_num * local_batch_size - indices = [] - lengths = [] - - for i in range(feature_batch): - f = i // local_batch_size - cur_bag_size = random.randint(0, multi_hot_sizes[f]) - cur_bag = set({}) - while len(cur_bag) < cur_bag_size: - cur_bag.add(random.randint(0, num_embeddings_list[f] - 1)) - - indices.extend(list(cur_bag)) - lengths.append(cur_bag_size) - - return torchrec.KeyedJaggedTensor( - keys=[feature_idx_to_name(feature_idx) for feature_idx in range(feature_num)], - values=torch.tensor( - indices, dtype=torch.int64 - ).cuda(), # key [0,1] on rank0, [2] on rank 1 - lengths=torch.tensor(lengths, dtype=torch.int64).cuda(), - ) - + batch_size_per_rank = batch_size // world_size + kjts = [] + all_kjts = [] + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, _ in enumerate(num_embeddings): + _, table_name = idx_to_name(embedding_collection_id, embedding_id) + scores_collection[table_name] = {} + step = 0 + for _ in range(num_iterations): + step += 1 + cur_indices, cur_lengths = [], [] + all_indices, all_lengths = [], [] + keys = [] + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, num_embedding in enumerate(num_embeddings): + feature_name, table_name = idx_to_name( + embedding_collection_id, embedding_id + ) + expected_scores: Dict[int, int] = scores_collection[table_name] + for sample_id in range(batch_size): + hotness = random.randint( + 0, multi_hot_sizes[embedding_collection_id] + ) + indices = [random.randint(0, (1 << 63) - 1) for _ in range(hotness)] + all_indices.extend(indices) + all_lengths.append(hotness) + if sample_id // batch_size_per_rank == rank: + cur_indices.extend(indices) + cur_lengths.append(hotness) + for index in indices: + update_scores(score_strategy, expected_scores, index, step) + keys.append(feature_name) + kjts.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor(cur_indices, dtype=torch.int64).cuda(), + lengths=torch.tensor(cur_lengths, dtype=torch.int64).cuda(), + ) + ) + return kjts -def run(args): - backend = "nccl" - dist.init_process_group(backend=backend) - local_rank = int(os.environ["LOCAL_RANK"]) - world_size = int(os.environ["WORLD_SIZE"]) - torch.cuda.set_device(local_rank) - device = torch.device(f"cuda:{local_rank}") - - all_table_names = [ - table_idx_to_name(feature_idx) - for feature_idx in range(args.num_embedding_table) - ] - - eb_configs = [ - torchrec.EmbeddingConfig( - name=table_idx_to_name(feature_idx), - embedding_dim=args.embedding_dim, - num_embeddings=args.num_embeddings_per_feature[feature_idx], - feature_names=[feature_idx_to_name(feature_idx)], +def generate_sparse_feature( + num_embedding_collections: int, + num_embeddings: List[int], + multi_hot_sizes: List[int], + rank: int, + world_size: int, + batch_size: int, + num_iterations: int, + score_strategy: str, + scores_collection: Dict[str, Dict[int, int]], + seed: int = 42, +): + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + batch_size_per_rank = batch_size // world_size + kjts = [] + all_kjts = [] + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, _ in enumerate(num_embeddings): + _, table_name = idx_to_name(embedding_collection_id, embedding_id) + scores_collection[table_name] = {} + step = 0 + for _ in range(num_iterations): + step += 1 + cur_indices, cur_lengths = [], [] + all_indices, all_lengths = [], [] + keys = [] + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, num_embedding in enumerate(num_embeddings): + feature_name, table_name = idx_to_name( + embedding_collection_id, embedding_id + ) + expected_scores: Dict[int, int] = scores_collection[table_name] + for sample_id in range(batch_size): + hotness = random.randint( + 0, multi_hot_sizes[embedding_collection_id] + ) + indices = [random.randint(0, (1 << 63) - 1) for _ in range(hotness)] + all_indices.extend(indices) + all_lengths.append(hotness) + if sample_id // batch_size_per_rank == rank: + cur_indices.extend(indices) + cur_lengths.append(hotness) + for index in indices: + update_scores(score_strategy, expected_scores, index, step) + keys.append(feature_name) + kjts.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor(cur_indices, dtype=torch.int64).cuda(), + lengths=torch.tensor(cur_lengths, dtype=torch.int64).cuda(), + ) ) - for feature_idx in range(args.num_embedding_table) - ] - ebc = torchrec.EmbeddingCollection( - device=torch.device("meta"), - tables=eb_configs, - ) - - if args.use_torch_opt: - optimizer_kwargs = { - "lr": args.learning_rate, - "betas": (args.beta1, args.beta2), - "weight_decay": args.weight_decay, - "eps": args.eps, - } - if args.optimizer_type == "sgd": - embedding_optimizer = torch.optim.SGD - elif args.optimizer_type == "adam": - embedding_optimizer = torch.optim.Adam - else: - raise ValueError("unknown optimizer type") - else: - optimizer_kwargs = { - "learning_rate": args.learning_rate, - "beta1": args.beta1, - "beta2": args.beta2, - "weight_decay": args.weight_decay, - "eps": args.eps, - } - if args.optimizer_type == "sgd": - optimizer_kwargs["optimizer"] = EmbOptimType.EXACT_SGD - elif args.optimizer_type == "adam": - optimizer_kwargs["optimizer"] = EmbOptimType.ADAM - elif args.optimizer_type == "exact_adagrad": - optimizer_kwargs["optimizer"] = EmbOptimType.EXACT_ADAGRAD - elif args.optimizer_type == "exact_row_wise_adagrad": - optimizer_kwargs["optimizer"] = EmbOptimType.EXACT_ROWWISE_ADAGRAD - else: - raise ValueError("unknown optimizer type") - - planner = get_planner(args, device, eb_configs) - - qcomm_forward_precision = get_comm_precission(args.fwd_a2a_precision) - qcomm_backward_precision = get_comm_precission(args.fwd_a2a_precision) - qcomm_codecs_registry = ( - get_qcomm_codecs_registry( - qcomms_config=QCommsConfig( - # pyre-ignore - forward_precision=qcomm_forward_precision, - # pyre-ignore - backward_precision=qcomm_backward_precision, + all_kjts.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor(all_indices, dtype=torch.int64).cuda(), + lengths=torch.tensor(all_lengths, dtype=torch.int64).cuda(), ) ) - if backend == "nccl" - else None - ) + return kjts, keys, all_kjts + + +class TestModel(nn.Module): + def __init__( + self, + embedding_modules: List[EmbeddingCollection], + ): + super().__init__() + self.embedding_modules = nn.ModuleList(embedding_modules) + + def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor: + embeddings_dict = [ + embedding_module(kjt).wait() for embedding_module in self.embedding_modules + ] + embeddings = [] + for embedding_dict in embeddings_dict: + for embedding in embedding_dict.values(): + embeddings.append(embedding.values()) + return torch.cat(embeddings, dim=0) + + +DATA_TYPE_NUM_BITS: Dict[DataType, int] = { + DataType.FP32: 32, + DataType.FP16: 16, + DataType.BF16: 16, +} + + +def apply_dmp( + model: torch.nn.Module, + optimizer_kwargs: Dict[str, Any], + device: torch.device, + score_strategy: DynamicEmbScoreStrategy = DynamicEmbScoreStrategy.LFU, + use_index_dedup: bool = False, + caching: bool = False, + cache_capacity_ratio: float = 0.5, + admit_strategy: AdmissionStrategy = None, +): + eb_configs = [] + dynamicemb_options_dict = {} + for n, m in model.named_modules(): + if type(m) in TORCHREC_TYPES: + eb_configs.extend(m.embedding_configs()) + for eb_config in eb_configs: + dim = eb_config.embedding_dim + tmp_type = eb_config.data_type + + embedding_type_bytes = DATA_TYPE_NUM_BITS[tmp_type] / 8 + emb_num_embeddings = ( + eb_config.num_embeddings * cache_capacity_ratio + if caching + else eb_config.num_embeddings + ) + emb_num_embeddings_next_power_of_2 = 2 ** math.ceil( + math.log2(emb_num_embeddings) + ) # HKV need embedding vector num is power of 2 + + # Calculate optimizer state dimension + from dynamicemb.dynamicemb_config import ( + data_type_to_dtype, + get_optimizer_state_dim, + ) + from dynamicemb_extensions import OptimizerType - if not args.use_torch_opt: - sharder = DynamicEmbeddingCollectionSharder( - qcomm_codecs_registry=qcomm_codecs_registry, - fused_params=optimizer_kwargs, - use_index_dedup=args.use_index_dedup, - ) - else: - sharder = DynamicEmbeddingCollectionSharder( - qcomm_codecs_registry=qcomm_codecs_registry, - use_index_dedup=args.use_index_dedup, - ) + # Map fbgemm EmbOptimType to dynamicemb OptimizerType + emb_opt_type = ( + optimizer_kwargs.get("optimizer") if optimizer_kwargs else None + ) + opt_type_map = { + EmbOptimType.EXACT_ROWWISE_ADAGRAD: OptimizerType.RowWiseAdaGrad, + EmbOptimType.SGD: OptimizerType.SGD, + EmbOptimType.EXACT_SGD: OptimizerType.SGD, + EmbOptimType.ADAM: OptimizerType.Adam, + EmbOptimType.EXACT_ADAGRAD: OptimizerType.AdaGrad, + } + opt_type = opt_type_map.get(emb_opt_type) if emb_opt_type else None + # Convert torchrec DataType to torch.dtype + torch_dtype = data_type_to_dtype(tmp_type) + optimizer_state_dim = ( + get_optimizer_state_dim(opt_type, dim, torch_dtype) + if opt_type + else 0 + ) - plan = planner.collective_plan(ebc, [sharder], dist.GroupMember.WORLD) + # Include optimizer state in HBM calculation + total_hbm_need = ( + embedding_type_bytes + * (dim + optimizer_state_dim) + * emb_num_embeddings_next_power_of_2 + ) - if args.use_torch_opt: - apply_optimizer_in_backward( - embedding_optimizer, - ebc.parameters(), - optimizer_kwargs, - ) + admission_counter = KVCounter( + max(1024 * 1024, emb_num_embeddings_next_power_of_2 // 4) + ) + dynamicemb_options_dict[eb_config.name] = DynamicEmbTableOptions( + global_hbm_for_values=total_hbm_need, + score_strategy=score_strategy, + initializer_args=DynamicEmbInitializerArgs( + mode=DynamicEmbInitializerMode.CONSTANT, + value=1e-1, + ), + bucket_capacity=emb_num_embeddings_next_power_of_2, + max_capacity=emb_num_embeddings_next_power_of_2, + caching=caching, + local_hbm_for_values=1024**3, + admit_strategy=admit_strategy, + admission_counter=admission_counter, + ) + planner = get_planner( + eb_configs, + {}, + dynamicemb_options_dict, + device, + ) - data_parallel_wrapper = DefaultDataParallelWrapper( - allreduce_comm_precision=args.allreduce_precision + fused_params = {} + fused_params["output_dtype"] = SparseType.FP32 + fused_params.update(optimizer_kwargs) + + sharder = DynamicEmbeddingCollectionSharder( + fused_params=fused_params, + use_index_dedup=use_index_dedup, ) - model = DistributedModelParallel( - module=ebc, + plan = planner.collective_plan(model, [sharder], dist.GroupMember.WORLD) + + # Same usage of TorchREC + dmp = DistributedModelParallel( + module=model, device=device, # pyre-ignore sharders=[sharder], plan=plan, - data_parallel_wrapper=data_parallel_wrapper, ) + return dmp - customized_scores = CustomizedScore(all_table_names) - ret: Dict[str, Dict[str, int]] = get_score(model) - prefix_path = "model" - - if ret is None: - return - else: - assert len(ret) == 1 and prefix_path in ret - scores_to_set: Dict[str, int] = {} - for i in range(args.num_embedding_table): - if args.score_strategies == DynamicEmbScoreStrategy.CUSTOMIZED: - scores_to_set[all_table_names[i]] = customized_scores.get( - all_table_names[i] +def create_model( + num_embedding_collections: int, + num_embeddings: List[int], + embedding_dim: int, + optimizer_kwargs: Dict[str, Any], + score_strategy: DynamicEmbScoreStrategy = DynamicEmbScoreStrategy.LFU, + use_index_dedup: bool = False, + caching: bool = False, + cache_capacity_ratio: float = 0.5, + admit_strategy: AdmissionStrategy = None, +): + ebc_list = [] + for embedding_collection_id in range(num_embedding_collections): + eb_configs = [] + for embedding_id, num_embedding in enumerate(num_embeddings): + feature_name, embedding_name = idx_to_name( + embedding_collection_id, embedding_id ) - set_score(model, {prefix_path: scores_to_set}) - - if local_rank == 0 and args.print_sharding_plan: - for collectionkey, plans in model._plan.plan.items(): - print(collectionkey) - for table_name, plan in plans.items(): - print(table_name, "\n", plan, "\n") - - def optimizer_with_params(): - if args.optimizer_type == "sgd": - return lambda params: torch.optim.SGD(params, lr=args.learning_rate) - elif args.optimizer_type == "adagrad": - return lambda params: torch.optim.Adagrad( - params, lr=args.learning_rate, eps=args.eps + eb_configs.append( + EmbeddingConfig( + name=embedding_name, + embedding_dim=embedding_dim, + num_embeddings=num_embedding, + feature_names=[feature_name], + data_type=DataType.FP32, + ) ) - elif args.optimizer_type == "rowwise_adagrad": - return lambda params: torch.optim.Adagrad( - params, lr=args.learning_rate, eps=args.eps + ebc_list.append( + EmbeddingCollection( + device=torch.device("meta"), + tables=eb_configs, ) - else: - raise ValueError("unknown optimizer type") - - Debugger() - - for i in range(args.num_iterations): - sparse_feature = generate_sparse_feature( - feature_num=args.num_embedding_table, - num_embeddings_list=args.num_embeddings_per_feature, - multi_hot_sizes=args.multi_hot_sizes, - local_batch_size=args.batch_size // world_size, ) - ret = model(sparse_feature) # => this is awaitable - - feature_names = [] - jagged_tensors = [] - for k, v in ret.items(): - feature_names.append(k) - jagged_tensors.append(v.values()) - - concatenated_tensor = torch.cat(jagged_tensors, dim=0) - reduced_tensor = concatenated_tensor.sum() - reduced_tensor.backward() - - scores_to_set: Dict[str, int] = {} - for i in range(args.num_embedding_table): - if args.score_strategies == DynamicEmbScoreStrategy.CUSTOMIZED: - scores_to_set[all_table_names[i]] = customized_scores.get( - all_table_names[i] - ) - - set_score(model, {prefix_path: scores_to_set}) - - DynamicEmbDump("debug_weight", model, optim=True) - DynamicEmbLoad("debug_weight", model, optim=True) - - table_names = {"model": ["t_0"]} - DynamicEmbDump("debug_weight_t0", model, table_names=table_names, optim=True) - DynamicEmbLoad("debug_weight_t0", model, table_names=table_names, optim=False) - - table_names = {"model": ["t_1"]} - DynamicEmbDump("debug_weight_t1", model, table_names=table_names, optim=False) - DynamicEmbLoad("debug_weight_t1", model, table_names=table_names, optim=False) + model = TestModel( + embedding_modules=ebc_list, + ) - dist.barrier() + model = apply_dmp( + model, + optimizer_kwargs, + torch.device(f"cuda:{torch.cuda.current_device()}"), + score_strategy=score_strategy, + use_index_dedup=use_index_dedup, + caching=caching, + cache_capacity_ratio=cache_capacity_ratio, + admit_strategy=admit_strategy, + ) + return model - if local_rank == 0: - shutil.rmtree("debug_weight") - shutil.rmtree("debug_weight_t0") - shutil.rmtree("debug_weight_t1") - dist.barrier() - dist.destroy_process_group() +def check_counter_table_checkpoint(x, y): + device = torch.cuda.current_device() + tables_x = get_dynamic_emb_module(x) + tables_y = get_dynamic_emb_module(y) + for table_x, table_y in zip(tables_x, tables_y): + for cnt_tx, cnt_ty in zip( + table_x._admission_counter, table_y._admission_counter + ): + assert cnt_tx.table_.size() == cnt_ty.table_.size() + + for keys, named_scores in cnt_tx._batched_export_keys_scores( + cnt_tx.table_.score_names_, torch.device(f"cuda:{device}") + ): + if keys.numel() == 0: + continue + freq_name = cnt_tx.table_.score_names_[0] + frequencies = named_scores[freq_name] + + score_args_lookup = [ + ScoreArg( + name=freq_name, + value=torch.zeros_like(frequencies), + policy=ScorePolicy.CONST, + is_return=True, + ) + ] + founds = torch.empty( + keys.numel(), dtype=torch.bool, device=device + ).fill_(False) + + cnt_ty.lookup(keys, score_args_lookup, founds) + + assert torch.equal(frequencies, score_args_lookup) + + +@click.command() +@click.option("--num-embedding-collections", type=int, required=True) +@click.option("--num-embeddings", type=str, required=True) +@click.option("--multi-hot-sizes", type=str, required=True) +@click.option("--embedding-dim", type=int, required=True) +@click.option("--save-path", type=str, required=True) +@click.option( + "--optimizer-type", + type=click.Choice(["sgd", "adam", "adagrad", "rowwise_adagrad"]), + required=True, +) +@click.option("--mode", type=click.Choice(["load", "dump"]), required=True) +@click.option( + "--score-strategy", + type=click.Choice(["timestamp", "step", "lfu"]), + required=True, +) +@click.option("--optim", type=bool, required=True) +@click.option("--counter", type=bool, required=True) +def test_model_load_dump( + num_embedding_collections: int, + num_embeddings: str, + multi_hot_sizes: str, + embedding_dim: int, + optimizer_type: str, + score_strategy: str, + mode: str, + save_path: str, + optim: bool, + counter: bool, + batch_size: int = 128, + num_iterations: int = 10, +): + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) -@record -def main(argv: List[str]) -> None: - parser = argparse.ArgumentParser(description="DynamicEmb dump load test") - parser.add_argument( - "--epochs", - type=int, - default=1, - help="number of epochs to train", - ) - parser.add_argument( - "--batch_size", - type=int, - default=16, - help="batch size to use for training", - ) - parser.add_argument( - "--num_iterations", - type=int, - default=100, - help="number of iterations", - ) - parser.add_argument( - "--num_embeddings_per_feature", - type=str, - default="65536,32768,409600,81920", - help="Comma separated max_ind_size per sparse feature. The number of embeddings" - " in each embedding table. 26 values are expected for the Criteo dataset.", - ) - parser.add_argument( - "--multi_hot_sizes", - type=str, - default="16,8,20,1", - help="Comma separated multihot size per sparse feature. 26 values are expected for the Criteo dataset.", - ) - parser.add_argument( - "--print_sharding_plan", - action="store_true", - help="Print the sharding plan used for each embedding table.", - ) - parser.add_argument( - "--fwd_a2a_precision", - type=str, - default="fp32", - choices=["fp32", "fp16", "bf16", "fp8"], - ) - parser.add_argument( - "--bck_a2a_precision", - type=str, - default="fp32", - choices=["fp32", "fp16", "bf16", "fp8"], - ) - parser.add_argument( - "--allreduce_precision", - type=str, - default="fp16", - choices=["fp16", "bf16", "fp32"], - ) - parser.add_argument( - "--embedding_dim", - type=int, - default=128, - help="Size of each embedding.", - ) - parser.add_argument( - "--dense_in_features", - type=int, - default=13, - help="dense_in_features.", - ) - parser.add_argument( - "--dense_arch_layer_sizes", - type=str, - default="512,256,128", - help="Comma separated layer sizes for dense arch.", - ) - parser.add_argument( - "--over_arch_layer_sizes", - type=str, - default="512,512,256,1", - help="Comma separated layer sizes for over arch.", - ) - parser.add_argument( - "--dcn_num_layers", - type=int, - default=3, - help="Number of DCN layers in interaction layer (only on dlrm with DCN).", - ) - parser.add_argument( - "--dcn_low_rank_dim", - type=int, - default=512, - help="Low rank dimension for DCN in interaction layer (only on dlrm with DCN).", - ) - parser.add_argument( - "--optimizer_type", - type=str, - default="adam", - choices=["sgd", "adam", "exact_adagrad", "row_wise_adagrad"], - help="optimizer type.", - ) + num_embeddings = [int(v) for v in num_embeddings.split(",")] + multi_hot_sizes = [int(v) for v in multi_hot_sizes.split(",")] - parser.add_argument( - "--learning_rate", - type=float, - default=0.1, - help="Learning rate.", - ) + for num_embedding, multi_hot_size in zip(num_embeddings, multi_hot_sizes): + if batch_size * num_iterations * multi_hot_size > num_embedding: + raise ValueError( + "batch_size * num_iterations * multi_hot_size > num_embedding, this may lead to eviction of dynamicemb and cause test fail" + ) - parser.add_argument( - "--beta1", - type=float, - default=0.9, - help="beta1.", - ) + optimizer_kwargs = get_optimizer_kwargs(optimizer_type) + score_strategy_ = get_score_strategy(score_strategy) + + ref_model = create_model( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + optimizer_kwargs=optimizer_kwargs, + score_strategy=score_strategy_, + admit_strategy=FrequencyAdmissionStrategy( + threshold=2 if counter else 1, + ), + ) + + expect_scores_collection: Dict[str, Dict[int, int]] = {} + kjts, feature_names, all_kjts = generate_sparse_feature( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + multi_hot_sizes=multi_hot_sizes, + rank=dist.get_rank(), + world_size=dist.get_world_size(), + batch_size=batch_size, + num_iterations=num_iterations, + score_strategy=score_strategy, + scores_collection=expect_scores_collection, + ) + + for kjt in kjts: + ret = ref_model(kjt) + loss = ( + ret.sum() * dist.get_world_size() + ) # scale the loss by world size to make the gradients consistent between different gpu settings + loss.backward() + + if mode == "dump": + shutil.rmtree(save_path, ignore_errors=True) + DynamicEmbDump(save_path, ref_model, optim=optim, counter=counter) + + if mode == "load": + model = create_model( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + optimizer_kwargs=optimizer_kwargs, + score_strategy=score_strategy_, + admit_strategy=FrequencyAdmissionStrategy( + threshold=2 if counter else 1, + ), + ) - parser.add_argument( - "--beta2", - type=float, - default=0.999, - help="beta1.", - ) + DynamicEmbLoad(save_path, model, optim=optim, counter=counter) + + if counter: + check_counter_table_checkpoint(model, ref_model) + + table_name_to_key_score_dict = {} + table_name_to_visited_key_dict = {} + for _, _, sharded_module in find_sharded_modules(model): + dynamic_emb_modules = get_dynamic_emb_module(sharded_module) + for dynamic_emb_module in dynamic_emb_modules: + for table_name, table, counter_table in zip( + dynamic_emb_module.table_names, + dynamic_emb_module.tables, + dynamic_emb_module._admission_counter, + ): + key_to_score = {} + visited_keys = set({}) + for batched_key, _, _, batched_score in batched_export_keys_values( + table.table, torch.device(f"cpu") + ): + for key, score in zip( + batched_key.tolist(), batched_score.tolist() + ): + key_to_score[key] = score + + for ( + keys, + named_scores, + ) in counter_table.table_._batched_export_keys_scores( + counter_table.table_.score_names_, torch.device(f"cpu") + ): + if keys.numel() == 0: + continue + for key in keys.tolist(): + visited_keys.add(key) + + table_name_to_key_score_dict[table_name] = key_to_score + table_name_to_visited_key_dict[table_name] = visited_keys + + for embedding_collection_idx, embedding_idx in product( + range(num_embedding_collections), range(len(num_embeddings)) + ): + feature_name, table_name = idx_to_name( + embedding_collection_idx, embedding_idx + ) + key_to_score_dict = table_name_to_key_score_dict[table_name].copy() + expect_scores = expect_scores_collection[table_name] + visited_keys = table_name_to_visited_key_dict[table_name] + + if score_strategy == "step" or score_strategy == "lfu": + for kjt in reversed(all_kjts): + keys = kjt[feature_name].values().tolist() + for key in keys: + if key % world_size == rank and key not in visited_keys: + assert ( + key in key_to_score_dict + ), f"Key {key} must exist in table of rank {rank}." + assert ( + key_to_score_dict[key] == expect_scores[key] + ), f"Expect {key_to_score_dict[key]} = {expect_scores[key]}" + # The idea is that the score of a newer key is greater than that of an older key. Therefore, I iterate through the input in reverse order and track the minimum score encountered. For each batch, the score should be lower than the minimum score from the previous batch. To avoid issues caused by duplicate keys, every time I access a key, I set its score to -inf. This ensures that if that key appears again, its score will be sufficiently small to remain below the minimum score. + elif score_strategy == "timestamp": + min_score = float("inf") + lasted_min_score = float("inf") + for kjt in reversed(all_kjts): + keys = kjt[feature_name].values().tolist() + for key in keys: + if key % world_size == rank and key not in visited_keys: + assert ( + key in key_to_score_dict + ), f"Key {key} must exist in table of rank {rank}." + else: + continue + + assert ( + key_to_score_dict[key] <= min_score + ), f"key {key} score {key_to_score_dict[key]} should be < min_score {min_score}" + lasted_min_score = min(lasted_min_score, key_to_score_dict[key]) + visited_keys.add(key) + + min_score = lasted_min_score + lasted_min_score = min_score + + else: + raise RuntimeError("Not supported score strategy.") + + if optim: + for kjt in kjts: + ret = model(kjt) + ret.sum().backward() + ref_ret = ref_model(kjt) + ref_ret.sum().backward() + + ref_model = ref_model.eval() + model = model.eval() + + with torch.inference_mode(): + for kjt in kjts: + ret = model(kjt) + ref_ret = ref_model(kjt) + assert torch.allclose(ret, ref_ret) + + +@click.command() +@click.option("--num-embedding-collections", type=int, required=True) +@click.option("--num-embeddings", type=str, required=True) +@click.option("--multi-hot-sizes", type=str, required=True) +@click.option("--embedding-dim", type=int, required=True) +@click.option("--save-path", type=str, required=True) +@click.option( + "--optimizer-type", + type=click.Choice(["sgd", "adam", "adagrad", "rowwise_adagrad"]), + required=True, +) +@click.option("--mode", type=click.Choice(["load", "dump"]), required=True) +@click.option( + "--score-strategy", + type=click.Choice(["timestamp", "step", "lfu"]), + required=True, +) +@click.option("--optim", type=bool, required=True) +def test_model_load_dump( + num_embedding_collections: int, + num_embeddings: str, + multi_hot_sizes: str, + embedding_dim: int, + optimizer_type: str, + score_strategy: str, + mode: str, + save_path: str, + optim: bool, + batch_size: int = 128, + num_iterations: int = 10, +): + rank = int(os.environ["LOCAL_RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) - parser.add_argument( - "--eps", - type=float, - default=0.001, - help="eps.", - ) + num_embeddings = [int(v) for v in num_embeddings.split(",")] + multi_hot_sizes = [int(v) for v in multi_hot_sizes.split(",")] - parser.add_argument( - "--weight_decay", - type=float, - default=0, - help="weight_decay.", - ) - parser.add_argument( - "--use_torch_opt", - action="store_true", - help="if is true , use torch register optimizer , or use torchrec", - ) - parser.add_argument( - "--allow_tf32", - action="store_true", - help="Enable TensorFloat-32 mode for matrix multiplications on A100 (or newer) GPUs.", - ) - parser.add_argument( - "--data_parallel_embeddings", - type=str, - default=None, - help="Comma separated data parallel embedding table ids.", - ) - parser.add_argument( - "--platform", - type=str, - default="a100", - choices=["a100", "h100", "h200"], - help="Platform, has different system spec", - ) - parser.add_argument( - "--bmlp_overlap", - action="store_true", - help="overlap bottom mlp", - ) - parser.add_argument( - "--enable_cuda_graph", - action="store_true", - help="enable cuda_graph", - ) - parser.add_argument( - "--skip_h2d", - action="store_true", - help="no input to the training pipeline", - ) - parser.add_argument( - "--skip_input_dist", - action="store_true", - help="skip the input distribution", - ) - parser.add_argument( - "--disable_pipeline", - action="store_true", - help="disable pipeline", - ) - parser.add_argument( - "--dynamicemb_num", - type=int, - default=2, - help="Number of dynamic embedding tables.", - ) - parser.add_argument( - "--use_index_dedup", - type=str2bool, - default=True, - help="Use index deduplication (default: True).", - ) + for num_embedding, multi_hot_size in zip(num_embeddings, multi_hot_sizes): + if batch_size * num_iterations * multi_hot_size > num_embedding: + raise ValueError( + "batch_size * num_iterations * multi_hot_size > num_embedding, this may lead to eviction of dynamicemb and cause test fail" + ) - parser.add_argument( - "--score_type", - type=str, - default="timestamp", - choices=["timestamp", "step", "custimized"], - help="score type string", - ) + optimizer_kwargs = get_optimizer_kwargs(optimizer_type) + score_strategy_ = get_score_strategy(score_strategy) - args = parser.parse_args() - - args.num_embeddings_per_feature = [ - int(v) for v in args.num_embeddings_per_feature.split(",") - ] - args.multi_hot_sizes = [int(v) for v in args.multi_hot_sizes.split(",")] - args.dense_arch_layer_sizes = [ - int(v) for v in args.dense_arch_layer_sizes.split(",") - ] - args.over_arch_layer_sizes = [int(v) for v in args.over_arch_layer_sizes.split(",")] - args.data_parallel_embeddings = ( - None - if args.data_parallel_embeddings is None - else [int(v) for v in args.data_parallel_embeddings.split(",")] + ref_model = create_model( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + optimizer_kwargs=optimizer_kwargs, + score_strategy=score_strategy_, ) - args.num_embedding_table = len(args.num_embeddings_per_feature) - if args.embedding_dim % 4 != 0: - print( - f"INFO: args.embedding_dim = {args.embedding_dim} is not aligned with 4, which can't use TorchREC raw embedding table , so all embedding table is dynamic embedding table" + expect_scores_collection: Dict[str, Dict[int, int]] = {} + kjts, feature_names, all_kjts = generate_sparse_feature( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + multi_hot_sizes=multi_hot_sizes, + rank=dist.get_rank(), + world_size=dist.get_world_size(), + batch_size=batch_size, + num_iterations=num_iterations, + score_strategy=score_strategy, + scores_collection=expect_scores_collection, + ) + + for kjt in kjts: + ret = ref_model(kjt) + loss = ( + ret.sum() * dist.get_world_size() + ) # scale the loss by world size to make the gradients consistent between different gpu settings + loss.backward() + + if mode == "dump": + shutil.rmtree(save_path, ignore_errors=True) + DynamicEmbDump(save_path, ref_model, optim=optim) + + if mode == "load": + model = create_model( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + optimizer_kwargs=optimizer_kwargs, + score_strategy=score_strategy_, ) - args.dynamicemb_num = args.num_embedding_table - - if args.platform == "a100": - args.intra_host_bw = 300e9 - args.inter_host_bw = 25e9 - args.hbm_cap = 80 * 1024 * 1024 * 1024 - elif args.platform == "h100": - args.intra_host_bw = 450e9 - args.inter_host_bw = 25e9 # TODO: need check - args.hbm_cap = 80 * 1024 * 1024 * 1024 - elif args.platform == "h200": - args.intra_host_bw = 450e9 - args.inter_host_bw = 450e9 - args.hbm_cap = 140 * 1024 * 1024 * 1024 - - if args.score_type == "timestamp": - args.score_strategies = DynamicEmbScoreStrategy.TIMESTAMP - elif args.score_type == "step": - args.score_strategies = DynamicEmbScoreStrategy.STEP - elif args.score_type == "custimized": - args.score_strategies = DynamicEmbScoreStrategy.CUSTOMIZED - - # Print all arguments - print("Arguments:") - for arg, value in vars(args).items(): - print(f"{arg}: {value}") - - run(args) + + DynamicEmbLoad(save_path, model, optim=optim) + + table_name_to_key_score_dict = {} + for _, _, sharded_module in find_sharded_modules(model): + dynamic_emb_modules = get_dynamic_emb_module(sharded_module) + for dynamic_emb_module in dynamic_emb_modules: + for table_name, table in zip( + dynamic_emb_module.table_names, dynamic_emb_module.tables + ): + key_to_score = {} + for batched_key, _, _, batched_score in batched_export_keys_values( + table.table, torch.device(f"cpu") + ): + for key, score in zip( + batched_key.tolist(), batched_score.tolist() + ): + key_to_score[key] = score + table_name_to_key_score_dict[table_name] = key_to_score + + for embedding_collection_idx, embedding_idx in product( + range(num_embedding_collections), range(len(num_embeddings)) + ): + feature_name, table_name = idx_to_name( + embedding_collection_idx, embedding_idx + ) + key_to_score_dict = table_name_to_key_score_dict[table_name].copy() + expect_scores = expect_scores_collection[table_name] + + if score_strategy == "step" or score_strategy == "lfu": + for kjt in reversed(all_kjts): + keys = kjt[feature_name].values().tolist() + for key in keys: + if key % world_size == rank: + assert ( + key in key_to_score_dict + ), f"Key {key} must exist in table of rank {rank}." + assert ( + key_to_score_dict[key] == expect_scores[key] + ), f"Expect {key_to_score_dict[key]} = {expect_scores[key]}" + # The idea is that the score of a newer key is greater than that of an older key. Therefore, I iterate through the input in reverse order and track the minimum score encountered. For each batch, the score should be lower than the minimum score from the previous batch. To avoid issues caused by duplicate keys, every time I access a key, I set its score to -inf. This ensures that if that key appears again, its score will be sufficiently small to remain below the minimum score. + elif score_strategy == "timestamp": + visited_keys = set({}) + min_score = float("inf") + lasted_min_score = float("inf") + for kjt in reversed(all_kjts): + keys = kjt[feature_name].values().tolist() + for key in keys: + if key % world_size == rank: + assert ( + key in key_to_score_dict + ), f"Key {key} must exist in table of rank {rank}." + else: + continue + + if key not in visited_keys: + assert ( + key_to_score_dict[key] <= min_score + ), f"key {key} score {key_to_score_dict[key]} should be < min_score {min_score}" + lasted_min_score = min( + lasted_min_score, key_to_score_dict[key] + ) + visited_keys.add(key) + + min_score = lasted_min_score + lasted_min_score = min_score + + else: + raise RuntimeError("Not supported score strategy.") + + if optim: + for kjt in kjts: + ret = model(kjt) + ret.sum().backward() + ref_ret = ref_model(kjt) + ref_ret.sum().backward() + + ref_model = ref_model.eval() + model = model.eval() + + with torch.inference_mode(): + for kjt in kjts: + ret = model(kjt) + ref_ret = ref_model(kjt) + assert torch.allclose(ret, ref_ret) if __name__ == "__main__": - main(sys.argv[1:]) + LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(LOCAL_RANK) + + dist.init_process_group(backend="nccl") + test_model_load_dump() + dist.destroy_process_group() diff --git a/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.sh b/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.sh index e1540352e..77c4fb887 100644 --- a/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.sh +++ b/corelib/dynamicemb/test/unit_tests/test_embedding_dump_load.sh @@ -1,61 +1,65 @@ #!/bin/bash set -e -export PYHKV_DEBUG=1 -export PYHKV_DEBUG_ITER=10 -export DYNAMICEMB_DUMP_LOAD_DEBUG=1 +NUM_EMBEDDING_COLLECTIONS=4 +NUM_EMBEDDINGS=1000000,1000000,1000000,1000000,1000000,1000000 +MULTI_HOT_SIZES=10,10,10,10,10,10 +NUM_GPUS=(1 4) +OPTIMIZER_TYPE=("adam" "sgd" "adagrad" "rowwise_adagrad") +INCLUDE_OPTIM=("True" "False") +SCORE_STRATEGY=("timestamp" "lfu" "step") +INCLUDE_COUNTER=("True" "False") -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --use_index_dedup True --batch_size 1024 || exit 1 +for num_gpus in ${NUM_GPUS[@]}; do + for optimizer_type in ${OPTIMIZER_TYPE[@]}; do + for include_optim in ${INCLUDE_OPTIM[@]}; do + for include_counter in ${INCLUDE_COUNTER[@]}; do + for score_strategy in ${SCORE_STRATEGY[@]}; do + echo "num_gpus: $num_gpus, optimizer_type: $optimizer_type, include_optim: $include_optim, include_counter: $include_counter, score_strategy: $score_strategy" + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + ./test/unit_tests/test_embedding_dump_load.py \ + --optimizer-type ${optimizer_type} \ + --score-strategy ${score_strategy} \ + --mode "dump" \ + --optim ${include_optim} \ + --counter ${include_counter} \ + --save-path "debug_weight_${optimizer_type}_${num_gpus}_${include_optim}_${include_counter}_${score_strategy}" \ + --num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \ + --num-embeddings $NUM_EMBEDDINGS \ + --multi-hot-sizes $MULTI_HOT_SIZES \ + --embedding-dim 16 || exit 1 + done + done + done + done +done -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --use_index_dedup False --batch_size 1024 || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 129 --use_index_dedup True --batch_size 1024 || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 129 --use_index_dedup False --batch_size 1024 || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup True --batch_size 1023 --multi_hot_sizes=20,1,101,49 || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup False --batch_size 1023 --multi_hot_sizes=20,1,101,49 || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup True --batch_size 1023 --multi_hot_sizes=20,49,101,1 || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup False --batch_size 1023 --multi_hot_sizes=20,1,101,49 --score_type="step" || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup True --batch_size 1023 --multi_hot_sizes=20,49,101,1 --score_type="step" || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup False --batch_size 1023 --multi_hot_sizes=20,1,101,49 --score_type="custimized" || exit 1 - -CUDA_VISIBLE_DEVICES=0,1 torchrun \ - --nnodes 1 \ - --nproc_per_node 2 \ - ./test/unit_tests/test_embedding_dump_load.py --print_sharding_plan --optimizer_type "adam" --embedding_dim 15 --use_index_dedup True --batch_size 1023 --multi_hot_sizes=20,49,101,1 --score_type="custimized" || exit 1 +for num_load_gpus in ${NUM_GPUS[@]}; do + for num_dump_gpus in ${NUM_GPUS[@]}; do + for optimizer_type in ${OPTIMIZER_TYPE[@]}; do + for include_optim in ${INCLUDE_OPTIM[@]}; do + for include_counter in ${INCLUDE_COUNTER[@]}; do + for score_strategy in ${SCORE_STRATEGY[@]}; do + echo "num_load_gpus: $num_load_gpus, num_dump_gpus: $num_dump_gpus, optimizer_type: $optimizer_type, include_optim: $include_optim, include_counter: $include_counter, score_strategy: $score_strategy" + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_load_gpus \ + ./test/unit_tests/test_embedding_dump_load.py \ + --optimizer-type ${optimizer_type} \ + --score-strategy ${score_strategy} \ + --mode "load" \ + --optim ${include_optim} \ + --counter ${include_counter} \ + --save-path "debug_weight_${optimizer_type}_${num_dump_gpus}_${include_optim}_${include_counter}_${score_strategy}" \ + --num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \ + --num-embeddings $NUM_EMBEDDINGS \ + --multi-hot-sizes $MULTI_HOT_SIZES \ + --embedding-dim 16 || exit 1 + done + done + done + done + done +done \ No newline at end of file diff --git a/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py b/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py new file mode 100644 index 000000000..ff3a375d4 --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/test_lfu_scores.py @@ -0,0 +1,324 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +from collections import defaultdict +from typing import Dict, List, Optional, Tuple + +import click +import torch +import torch.distributed as dist +import torch.nn as nn +from dynamicemb import DynamicEmbScoreStrategy +from dynamicemb.dump_load import find_sharded_modules, get_dynamic_emb_module +from dynamicemb.key_value_table import batched_export_keys_values +from test_embedding_dump_load import create_model, get_optimizer_kwargs, idx_to_name +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def generate_deterministic_sparse_features_with_frequency_tracking( + num_embedding_collections: int, + num_embeddings: List[int], + multi_hot_sizes: List[int], + rank: int, + world_size: int, + batch_size: int, + num_iterations: int, + seed: int = 42, + caching: bool = False, +) -> Tuple[List[KeyedJaggedTensor], Dict[str, Dict[int, int]]]: + """ + Generate deterministic sparse features and track frequency for each embedding table. + + Args: + caching: If True, generate more unique keys to trigger cache eviction + + Returns: + kjts: List of KeyedJaggedTensor for each iteration + table_frequency_counters: Dict mapping table_name -> {key: frequency} + """ + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + batch_size_per_rank = batch_size // world_size + kjts = [] + table_frequency_counters = {} + + # Initialize frequency counters for each table + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, num_embedding in enumerate(num_embeddings): + _, embedding_name = idx_to_name(embedding_collection_id, embedding_id) + table_frequency_counters[embedding_name] = defaultdict(int) + + for iteration in range(num_iterations): + cur_indices = [] + cur_lengths = [] + keys = [] + + for embedding_collection_id in range(num_embedding_collections): + for embedding_id, num_embedding in enumerate(num_embeddings): + feature_name, embedding_name = idx_to_name( + embedding_collection_id, embedding_id + ) + + for sample_id in range(batch_size): + hotness = random.randint( + 1, multi_hot_sizes[embedding_collection_id] + ) + if caching: + # In caching mode, use wider range to trigger eviction + # Use 80% of num_embedding to generate enough unique keys + max_key = int(num_embedding * 0.8) - 1 + else: + # In storage-only mode, limit to 100 keys for more duplicates + max_key = min(num_embedding - 1, 100) + indices = [random.randint(0, max_key) for _ in range(hotness)] + # Track frequency for all generated indices + for idx in indices: + table_frequency_counters[embedding_name][idx] += 1 + + if sample_id // batch_size_per_rank == rank: + cur_indices.extend(indices) + cur_lengths.append(hotness) + + keys.append(feature_name) + + kjts.append( + KeyedJaggedTensor.from_lengths_sync( + keys=keys, + values=torch.tensor(cur_indices, dtype=torch.int64).cuda(), + lengths=torch.tensor(cur_lengths, dtype=torch.int64).cuda(), + ) + ) + + # Convert defaultdicts to regular dicts + for table_name in table_frequency_counters: + table_frequency_counters[table_name] = dict( + table_frequency_counters[table_name] + ) + + return kjts, table_frequency_counters + + +def local_DynamicEmbDump( + model: nn.Module, + table_names: Optional[Dict[str, List[str]]] = None, + optim: Optional[bool] = False, + pg: Optional[dist.ProcessGroup] = None, +) -> Dict[str, Dict[int, int]]: + """ + Load scores from dynamic embedding tables directly (without disk I/O). + + Returns: + Dict[table_name, Dict[key, score]]: Scores organized by table name + """ + + if torch.cuda.is_available(): + torch.cuda.synchronize() + dist.barrier(group=pg, device_ids=[torch.cuda.current_device()]) + device = torch.device(f"cuda:{torch.cuda.current_device()}") + + batch_size = 65536 + + all_table_scores = {} + + for _, collection_name, sharded_module in find_sharded_modules(model, ""): + dynamic_emb_modules = get_dynamic_emb_module(sharded_module) + + for dynamic_emb_module in dynamic_emb_modules: + dynamic_emb_module.flush() + + for table_name, table in zip( + dynamic_emb_module.table_names, dynamic_emb_module.tables + ): + if table_names is not None and table_name not in set( + table_names[collection_name] + ): + continue + + table_key_scores = {} + + for keys, _, _, scores in batched_export_keys_values( + table.table, device, batch_size + ): + for key, score in zip(keys, scores): + table_key_scores[int(key)] = int(score) + + all_table_scores[table_name] = table_key_scores + + if torch.cuda.is_available(): + torch.cuda.synchronize() + + dist.barrier(group=pg, device_ids=[torch.cuda.current_device()]) + + return all_table_scores + + +def validate_lfu_scores( + expected_frequencies: Dict[str, Dict[int, int]], + actual_scores: Dict[str, Dict[int, int]], + tolerance: float = 0.0, +): + """ + Validate that actual scores match expected frequencies for LFU strategy. + + Returns: + (is_valid, error_message) + """ + for table_name in expected_frequencies: + if table_name not in actual_scores: + assert ( + table_name in actual_scores + ), f"Table {table_name} missing from actual scores" + + expected = expected_frequencies[table_name] + actual = actual_scores[table_name] + + for key in actual.keys(): + exp_freq = expected[key] + act_score = actual[key] + assert ( + exp_freq == act_score + ), f"Table {table_name}, Key {key}: Expected frequency: {exp_freq}, Actual score: {act_score}" + + +@click.command() +@click.option("--num-embedding-collections", type=int, default=1) +@click.option("--num-embeddings", type=str, default="1000") +@click.option("--multi-hot-sizes", type=str, default="3") +@click.option("--embedding-dim", type=int, default=16) +@click.option( + "--optimizer-type", + type=click.Choice(["sgd", "adam", "adagrad", "rowwise_adagrad"]), + default="sgd", +) +@click.option("--batch-size", type=int, default=16) +@click.option("--num-iterations", type=int, default=3) +@click.option("--tolerance", type=float, default=0.0) +@click.option("--caching", is_flag=True, help="Enable cache + storage architecture") +@click.option( + "--cache-capacity-ratio", + type=float, + default=0.5, + help="Cache capacity as ratio of storage capacity (only used when --caching is enabled)", +) +def test_lfu_score_validation( + num_embedding_collections: int, + num_embeddings: str, + multi_hot_sizes: str, + embedding_dim: int, + optimizer_type: str, + batch_size: int, + num_iterations: int, + tolerance: float, + caching: bool, + cache_capacity_ratio: float, +): + """Test LFU score correctness by comparing with naive frequency counting. + + This test supports two modes: + - Storage-only (default): Tests LFU scores in storage directly + - Cache + Storage (--caching): Tests LFU score propagation through cache to storage + """ + + num_embeddings = [int(v) for v in num_embeddings.split(",")] + multi_hot_sizes = [int(v) for v in multi_hot_sizes.split(",")] + use_index_dedup = True + + if not caching: + for num_embedding, multi_hot_size in zip(num_embeddings, multi_hot_sizes): + if batch_size * num_iterations * multi_hot_size > num_embedding: + raise ValueError( + "batch_size * num_iterations * multi_hot_size > num_embedding, " + "this may lead to eviction of dynamicemb and cause test fail" + ) + + print(f"Configuration:") + print(f" - Embedding collections: {num_embedding_collections}") + print(f" - Num embeddings: {num_embeddings}") + print(f" - Multi-hot sizes: {multi_hot_sizes}") + print(f" - Embedding dim: {embedding_dim}") + print(f" - Optimizer: {optimizer_type}") + print(f" - Batch size: {batch_size}") + print(f" - Iterations: {num_iterations}") + print(f" - Tolerance: {tolerance}") + print(f" - Use index dedup: {use_index_dedup}") + if caching: + print(f" - Caching: ENABLED ✓") + print(f" - Cache capacity ratio: {cache_capacity_ratio}") + else: + print(f" - Caching: DISABLED") + + # Create model + optimizer_kwargs = get_optimizer_kwargs(optimizer_type) + model = create_model( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + embedding_dim=embedding_dim, + optimizer_kwargs=optimizer_kwargs, + score_strategy=DynamicEmbScoreStrategy.LFU, + use_index_dedup=use_index_dedup, + caching=caching, + cache_capacity_ratio=cache_capacity_ratio if caching else 0.1, + ) + + # Generate features with frequency tracking + ( + kjts, + expected_frequencies, + ) = generate_deterministic_sparse_features_with_frequency_tracking( + num_embedding_collections=num_embedding_collections, + num_embeddings=num_embeddings, + multi_hot_sizes=multi_hot_sizes, + rank=dist.get_rank(), + world_size=dist.get_world_size(), + batch_size=batch_size, + num_iterations=num_iterations, + caching=caching, + ) + + # Run forward passes to populate frequency information + if caching: + print(f"\nRunning {num_iterations} iterations with cache enabled...") + else: + print(f"\nRunning {num_iterations} iterations...") + + for iteration, kjt in enumerate(kjts): + ret = model(kjt) + torch.cuda.synchronize() + loss = ret.sum() * dist.get_world_size() + loss.backward() + torch.cuda.synchronize() + + # Extract actual scores + actual_scores = local_DynamicEmbDump(model, optim=True) + + torch.cuda.synchronize() + + # Validate scores + validate_lfu_scores(expected_frequencies, actual_scores, tolerance) + + +if __name__ == "__main__": + LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0)) + torch.cuda.set_device(LOCAL_RANK) + + dist.init_process_group(backend="nccl") + test_lfu_score_validation() + dist.barrier() + dist.destroy_process_group() diff --git a/corelib/dynamicemb/test/unit_tests/test_lfu_scores.sh b/corelib/dynamicemb/test/unit_tests/test_lfu_scores.sh new file mode 100644 index 000000000..253f8797b --- /dev/null +++ b/corelib/dynamicemb/test/unit_tests/test_lfu_scores.sh @@ -0,0 +1,132 @@ +#!/bin/bash +set -e + +# Test configurations +NUM_EMBEDDING_COLLECTIONS=2 +NUM_EMBEDDINGS=10000,10000,10000,10000 +MULTI_HOT_SIZES=5,5,5,5 +EMBEDDING_DIM=16 +NUM_GPUS=(1 4) +OPTIMIZER_TYPE=("sgd" "adam" "adagrad" "rowwise_adagrad") +BATCH_SIZE=32 +NUM_ITERATIONS=10 +TOLERANCE=0.0 + +# Cache configurations +CACHING_MODES=("False" "True") +CACHE_CAPACITY_RATIO=0.3 # 30% cache capacity to trigger evictions + + +for num_gpus in ${NUM_GPUS[@]}; do + for optimizer_type in ${OPTIMIZER_TYPE[@]}; do + echo "" + echo "----------------------------------------" + echo "Test: Storage-Only | GPUs: $num_gpus | Optimizer: $optimizer_type" + echo "----------------------------------------" + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + ./test/unit_tests/test_lfu_scores.py \ + --num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \ + --num-embeddings $NUM_EMBEDDINGS \ + --multi-hot-sizes $MULTI_HOT_SIZES \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type ${optimizer_type} \ + --batch-size $BATCH_SIZE \ + --num-iterations $NUM_ITERATIONS \ + --tolerance $TOLERANCE || exit 1 + done +done + + +for num_gpus in ${NUM_GPUS[@]}; do + for optimizer_type in ${OPTIMIZER_TYPE[@]}; do + echo "" + echo "----------------------------------------" + echo "Test: Cache+Storage | GPUs: $num_gpus | Optimizer: $optimizer_type | Cache Ratio: $CACHE_CAPACITY_RATIO" + echo "----------------------------------------" + torchrun \ + --nnodes 1 \ + --nproc_per_node $num_gpus \ + ./test/unit_tests/test_lfu_scores.py \ + --num-embedding-collections $NUM_EMBEDDING_COLLECTIONS \ + --num-embeddings $NUM_EMBEDDINGS \ + --multi-hot-sizes $MULTI_HOT_SIZES \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type ${optimizer_type} \ + --batch-size $BATCH_SIZE \ + --num-iterations $NUM_ITERATIONS \ + --tolerance $TOLERANCE \ + --caching \ + --cache-capacity-ratio $CACHE_CAPACITY_RATIO || exit 1 + done +done + + + +# High-frequency test: more iterations to test frequency accumulation + +HIGH_FREQ_ITERATIONS=50 +for caching_mode in "without-cache" "with-cache"; do + echo "" + echo "----------------------------------------" + echo "Test: High Frequency ($HIGH_FREQ_ITERATIONS iters) | Mode: $caching_mode" + echo "----------------------------------------" + if [ "$caching_mode" = "without-cache" ]; then + torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + ./test/unit_tests/test_lfu_scores.py \ + --num-embedding-collections 1 \ + --num-embeddings 5000 \ + --multi-hot-sizes 3 \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type sgd \ + --batch-size 16 \ + --num-iterations $HIGH_FREQ_ITERATIONS \ + --tolerance $TOLERANCE || exit 1 + else + torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + ./test/unit_tests/test_lfu_scores.py \ + --num-embedding-collections 1 \ + --num-embeddings 5000 \ + --multi-hot-sizes 3 \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type sgd \ + --batch-size 16 \ + --num-iterations $HIGH_FREQ_ITERATIONS \ + --tolerance $TOLERANCE \ + --caching \ + --cache-capacity-ratio 0.4 || exit 1 + fi +done + + + +EVICTION_CACHE_RATIO_1=0.08 # 2% - Very small cache +EVICTION_BATCH_SIZE_1=64 # Large batch size +EVICTION_ITERATIONS_1=25 # Many iterations + +for optimizer_type in "sgd" "adam"; do + echo "" + echo "----------------------------------------" + echo "Test: Ultra-small Cache | Optimizer: $optimizer_type | Cache Ratio: $EVICTION_CACHE_RATIO_1" + echo "----------------------------------------" + torchrun \ + --nnodes 1 \ + --nproc_per_node 1 \ + ./test/unit_tests/test_lfu_scores.py \ + --num-embedding-collections 1 \ + --num-embeddings 10000 \ + --multi-hot-sizes 5 \ + --embedding-dim $EMBEDDING_DIM \ + --optimizer-type ${optimizer_type} \ + --batch-size $EVICTION_BATCH_SIZE_1 \ + --num-iterations $EVICTION_ITERATIONS_1 \ + --tolerance $TOLERANCE \ + --caching \ + --cache-capacity-ratio $EVICTION_CACHE_RATIO_1 || exit 1 +done + diff --git a/corelib/dynamicemb/test/unit_tests/test_pooled_embedding.sh b/corelib/dynamicemb/test/unit_tests/test_pooled_embedding.sh index d7702bbeb..67e87f9ca 100644 --- a/corelib/dynamicemb/test/unit_tests/test_pooled_embedding.sh +++ b/corelib/dynamicemb/test/unit_tests/test_pooled_embedding.sh @@ -33,12 +33,12 @@ CUDA_VISIBLE_DEVICES=0,1 torchrun \ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ --nnodes 1 \ - --nproc_per_node 4 \ + --nproc_per_node 2 \ ./test/unit_tests/test_pooled_embedding_fw.py --print_sharding_plan --optimizer_type "adam" --batch_size=1024 --num_embeddings_per_feature=8388608,4194304,524288,1048576 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \ --nnodes 1 \ - --nproc_per_node 8 \ + --nproc_per_node 2 \ ./test/unit_tests/test_pooled_embedding_fw.py --print_sharding_plan --optimizer_type "adam" --batch_size=1024 --num_embeddings_per_feature=8388608,4194304,524288,1048576 diff --git a/corelib/dynamicemb/test/unit_tests/test_sequence_embedding.sh b/corelib/dynamicemb/test/unit_tests/test_sequence_embedding.sh index 13ecc2ba6..4cf0a831d 100644 --- a/corelib/dynamicemb/test/unit_tests/test_sequence_embedding.sh +++ b/corelib/dynamicemb/test/unit_tests/test_sequence_embedding.sh @@ -50,12 +50,12 @@ CUDA_VISIBLE_DEVICES=0,1 torchrun \ CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun \ --nnodes 1 \ - --nproc_per_node 4 \ + --nproc_per_node 2 \ ./test/unit_tests/test_sequence_embedding_fw.py --print_sharding_plan --optimizer_type "adam" --use_index_dedup True --batch_size 1024 --num_embeddings_per_feature=8388608,4194304,524288,1048576 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun \ --nnodes 1 \ - --nproc_per_node 8 \ + --nproc_per_node 2 \ ./test/unit_tests/test_sequence_embedding_fw.py --print_sharding_plan --optimizer_type "adam" --use_index_dedup True --batch_size 1024 --num_embeddings_per_feature=8388608,4194304,524288,1048576 # Test sequence embedding's backward on a single GPU with different ["use_index_dedup", "dim", "batch_size", "multi_hot_sizes"] diff --git a/corelib/dynamicemb/test/unit_tests/test_twin_module.sh b/corelib/dynamicemb/test/unit_tests/test_twin_module.sh index b632fb8fc..f381565c6 100644 --- a/corelib/dynamicemb/test/unit_tests/test_twin_module.sh +++ b/corelib/dynamicemb/test/unit_tests/test_twin_module.sh @@ -1,4 +1,4 @@ set -e torchrun --nproc_per_node=1 -m pytest -svv test/unit_tests/test_twin_module.py -torchrun --nproc_per_node=4 -m pytest -svv test/unit_tests/test_twin_module.py +torchrun --nproc_per_node=2 -m pytest -svv test/unit_tests/test_twin_module.py diff --git a/third_party/HierarchicalKV b/third_party/HierarchicalKV index 012237dd6..0ec9aa3ca 160000 --- a/third_party/HierarchicalKV +++ b/third_party/HierarchicalKV @@ -1 +1 @@ -Subproject commit 012237dd64647cc94797e8270cacf11fe7032fd7 +Subproject commit 0ec9aa3ca3e8164f902aefd68bef20d729f80530