-
Notifications
You must be signed in to change notification settings - Fork 95
Simplify to_constant_volume_bins #562
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,7 +37,7 @@ | |
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def to_constant_volume_bins( # noqa: C901, PLR0915 | ||
| def to_constant_volume_bins( # noqa: C901 | ||
| items: dict[int, float] | list[Any], | ||
| max_volume: float, | ||
| *, | ||
|
|
@@ -57,8 +57,9 @@ def to_constant_volume_bins( # noqa: C901, PLR0915 | |
| Args: | ||
| items (dict[int, float] | list[float] | list[tuple]): Items to distribute, | ||
| provided as either: | ||
| - Dictionary with numeric weights as values | ||
| - List of numeric weights | ||
| - Dictionary with numeric weights as values: maps system_idx -> weight | ||
| - List of numeric weights. Useful when you care about how many bins | ||
| are needed since the index isn't tracked. | ||
| - List of tuples containing weights (requires weight_pos or key) | ||
| max_volume (float): Maximum allowed weight sum per bin. | ||
| weight_pos (int | None): For tuple lists, index of weight in each tuple. | ||
|
|
@@ -78,132 +79,63 @@ def to_constant_volume_bins( # noqa: C901, PLR0915 | |
| - List of lists of tuples if input was a list of tuples | ||
|
|
||
| Raises: | ||
| TypeError: If input is not iterable. | ||
| ValueError: If weight_pos or key is not provided for tuple list input, | ||
| or if lower_bound >= upper_bound. | ||
| """ | ||
|
|
||
| def _get_bins[T](lst: list[T], ndx: list[int]) -> list[T]: | ||
| return [lst[n] for n in ndx] | ||
|
|
||
| def _argmax_bins(lst: list[float]) -> int: | ||
| return max(range(len(lst)), key=lambda idx: lst[idx]) | ||
|
|
||
| def _rev_argsort_bins(lst: list[float]) -> list[int]: | ||
| return sorted(range(len(lst)), key=lambda i: -lst[i]) | ||
| if lower_bound is not None and upper_bound is not None and lower_bound >= upper_bound: | ||
| raise ValueError("lower_bound is greater or equal to upper_bound") | ||
|
|
||
| if not hasattr(items, "__len__"): | ||
| raise TypeError("items must be iterable") | ||
| if len(items) == 0: | ||
| return [] | ||
|
|
||
| if not isinstance(items, dict) and len(items) > 0 and hasattr(items[0], "__len__"): | ||
| if weight_pos is not None: | ||
| key = lambda x: x[weight_pos] # noqa: E731 | ||
| if key is None: | ||
| raise ValueError("Must provide weight_pos or key for tuple list") | ||
|
|
||
| if not isinstance(items, dict) and key: | ||
| new_dict = dict(enumerate(items)) | ||
| items = {idx: key(val) for idx, val in enumerate(items)} | ||
| is_tuple_list = True | ||
| # Normalize input to (weight, payload) entries. The payload is whatever gets | ||
| # placed in the output bin: the dict key, the original tuple, or the weight. | ||
| is_dict = isinstance(items, dict) | ||
| if is_dict: | ||
| entries = [(weight, k) for k, weight in items.items()] # ty: ignore[unresolved-attribute] | ||
| # list of objects: dispatch on how to extract the weight from each item | ||
| elif weight_pos is not None: | ||
| # weight lives at a fixed tuple/list position; payload is the original item | ||
| entries = [(item[weight_pos], item) for item in items] # ty: ignore[not-subscriptable] | ||
| elif key is not None: | ||
| # custom extractor for arbitrary item types; payload is the original item | ||
| entries = [(key(item), item) for item in items] | ||
| elif isinstance(items[0], (tuple, list)): | ||
| # structured items but caller didn't say how to extract a weight | ||
| raise ValueError("Must provide weight_pos or key for tuple list") | ||
| else: | ||
| is_tuple_list = False | ||
|
|
||
| if isinstance(items, dict): | ||
| # get keys and values (weights) | ||
| keys = list(items) | ||
| vals = list(items.values()) | ||
|
|
||
| # sort weights decreasingly | ||
| n_dcs = _rev_argsort_bins(vals) | ||
|
|
||
| weights = _get_bins(vals, n_dcs) | ||
| keys = _get_bins(keys, n_dcs) | ||
|
|
||
| bins = [[]] if is_tuple_list else [{}] | ||
| else: | ||
| weights = sorted(items, key=lambda x: -x) | ||
| bins = [[]] | ||
|
|
||
| # find the valid indices | ||
| if lower_bound is not None and upper_bound is not None and lower_bound < upper_bound: | ||
| valid_ndcs = filter( | ||
| lambda i: lower_bound < weights[i] < upper_bound, range(len(weights)) | ||
| ) | ||
| elif lower_bound is not None: | ||
| valid_ndcs = filter(lambda i: lower_bound < weights[i], range(len(weights))) | ||
| elif upper_bound is not None: | ||
| valid_ndcs = filter(lambda i: weights[i] < upper_bound, range(len(weights))) | ||
| elif lower_bound is None and upper_bound is None: | ||
| valid_ndcs = range(len(weights)) | ||
| elif lower_bound >= upper_bound: | ||
| raise ValueError("lower_bound is greater or equal to upper_bound") | ||
|
|
||
| valid_ndcs = list(valid_ndcs) | ||
|
|
||
| weights = _get_bins(weights, valid_ndcs) | ||
|
Collaborator
Author
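There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The above logic to get all of the valid indices is simplified to two filters on `entries`: `entries = [e for e in entries if e[0] > lower_bound]` when `lower_bound` is given, and the matching `e[0] < upper_bound` filter when `upper_bound` is given. |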
||
|
|
||
| if isinstance(items, dict): | ||
| keys = _get_bins(keys, valid_ndcs) | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is no longer needed since we now operate on the new (weight, payload) tuples, which automatically carry the keys in the `entries` list. |
||
|
|
||
| # prepare array containing the current weight of the bins | ||
| weight_sum = [0.0] | ||
|
|
||
| # iterate through the weight list, starting with heaviest | ||
| for item, weight in enumerate(weights): | ||
| if isinstance(items, dict): | ||
| item_key = keys[item] | ||
|
|
||
| # find candidate bins where the weight might fit | ||
| candidate_bins = list( | ||
| filter(lambda i: weight_sum[i] + weight <= max_volume, range(len(weight_sum))) | ||
| ) | ||
|
|
||
| if candidate_bins: # if there are candidates where it fits | ||
| # find the fullest bin where this item fits and assign it | ||
| candidate_index = _argmax_bins(_get_bins(weight_sum, candidate_bins)) | ||
| b = candidate_bins[candidate_index] | ||
|
|
||
| # if this weight doesn't fit in any existent bin | ||
| elif item > 0: | ||
| # note! if this is the very first item then there is already an | ||
| # empty bin open so we don't need to open another one. | ||
|
|
||
| # open a new bin | ||
| b = len(weight_sum) | ||
| weight_sum.append(0.0) | ||
| if isinstance(items, dict): | ||
| bins.append([] if is_tuple_list else {}) | ||
| else: | ||
| bins.append([]) | ||
|
|
||
| # if we are at the very first item, use the empty bin already open | ||
| else: | ||
| b = 0 | ||
|
|
||
| # put it in | ||
| if isinstance(items, dict): | ||
| bin_ = bins[b] | ||
| if is_tuple_list: | ||
| if not isinstance(bin_, list): | ||
| raise TypeError("bins contain lists when tuple-list mode is used") | ||
| bin_.append(item_key) | ||
| elif isinstance(bin_, dict): | ||
| bin_[item_key] = weight | ||
| else: | ||
| bin_ = bins[b] | ||
| if not isinstance(bin_, list): | ||
| raise TypeError("bins contain lists when items is not dict") | ||
| bin_.append(weight) | ||
|
|
||
| # increase weight sum of the bin and continue with | ||
| # next item | ||
| weight_sum[b] += weight | ||
|
|
||
| if not is_tuple_list: | ||
| return bins | ||
| return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in bins] | ||
| # plain numeric weights; the item is its own payload | ||
| entries = [(weight, weight) for weight in items] | ||
|
|
||
| if lower_bound is not None: | ||
| entries = [e for e in entries if e[0] > lower_bound] | ||
| if upper_bound is not None: | ||
| entries = [e for e in entries if e[0] < upper_bound] | ||
|
|
||
| # Pack heaviest first, opening a new bin only when nothing fits. | ||
| entries.sort(key=lambda e: -e[0]) | ||
| bin_entries: list[list[tuple[float, Any]]] = [] | ||
| bin_sums: list[float] = [] | ||
| for weight, payload in entries: | ||
| # get all bin indices that can fit this payload | ||
| candidate_bin_indices = [ | ||
| i for i, s in enumerate(bin_sums) if s + weight <= max_volume | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. One thing I noticed was missing in the original implementation (I didn't add it here): I think we need an assert to ensure that all weights are <= max_volume, because if there is an item that exceeds max_volume, we can never fit it into a bin within the limit, not even a new bin.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we add these lines around line 120?

```python
# warn on oversized items — they'll still be placed in a bin by themselves,
# but that bin will exceed max_volume
for weight, payload in entries:
    if weight > max_volume:
        logger.warning(
            "item %r has weight %s > max_volume %s; placing in its own bin anyway",
            payload,
            weight,
            max_volume,
        )
```

Right now, we allow items to overflow the max bin size. Some tests fail if we add this check and throw an exception instead.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I actually wanted this to be an exception, but I can see cases where people might want the weight to exceed the max volume, so maybe it's not necessary to add this in. |
||
| ] | ||
| if candidate_bin_indices: | ||
| # get the idx of the most-full bin that has enough space | ||
| b = max(candidate_bin_indices, key=lambda i: bin_sums[i]) | ||
| else: # no bin has enough space. Create a new bin | ||
| b = len(bin_sums) | ||
| bin_entries.append([]) | ||
| bin_sums.append(0.0) | ||
| bin_entries[b].append((weight, payload)) | ||
| bin_sums[b] += weight | ||
|
|
||
| if is_dict: | ||
| return [{payload: weight for weight, payload in bin_} for bin_ in bin_entries] | ||
| return [[payload for _, payload in bin_] for bin_ in bin_entries] | ||
|
|
||
|
|
||
| def measure_model_memory_forward(state: SimState, model: ModelInterface) -> float: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
All of the above logic to get this reverse-sorted weights list is simplified to
`entries.sort(key=lambda e: -e[0])` below.