From 0341adbaef4508cc10d5913d199b4993f9abda55 Mon Sep 17 00:00:00 2001 From: Curtis Chong Date: Thu, 14 May 2026 18:03:34 -0700 Subject: [PATCH] add docs for autobatching params clean fn cleanup ruff add back len check manual comments backup better logic add max volume check remove fix ruff --- torch_sim/autobatching.py | 170 ++++++++++++-------------------------- 1 file changed, 51 insertions(+), 119 deletions(-) diff --git a/torch_sim/autobatching.py b/torch_sim/autobatching.py index 9671ed97..39eff6db 100644 --- a/torch_sim/autobatching.py +++ b/torch_sim/autobatching.py @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) -def to_constant_volume_bins( # noqa: C901, PLR0915 +def to_constant_volume_bins( # noqa: C901 items: dict[int, float] | list[Any], max_volume: float, *, @@ -57,8 +57,9 @@ def to_constant_volume_bins( # noqa: C901, PLR0915 Args: items (dict[int, float] | list[float] | list[tuple]): Items to distribute, provided as either: - - Dictionary with numeric weights as values - - List of numeric weights + - Dictionary with numeric weights as values: maps system_idx -> weight + - List of numeric weights. Useful when you care about how many bins + are needed since the index isn't tracked. - List of tuples containing weights (requires weight_pos or key) max_volume (float): Maximum allowed weight sum per bin. weight_pos (int | None): For tuple lists, index of weight in each tuple. @@ -78,132 +79,63 @@ def to_constant_volume_bins( # noqa: C901, PLR0915 - List of lists of tuples if input was a list of tuples Raises: - TypeError: If input is not iterable. ValueError: If weight_pos or key is not provided for tuple list input, or if lower_bound >= upper_bound. """ - - def _get_bins[T](lst: list[T], ndx: list[int]) -> list[T]: - return [lst[n] for n in ndx] - - def _argmax_bins(lst: list[float]) -> int: - return max(range(len(lst)), key=lambda idx: lst[idx]) - - def _rev_argsort_bins(lst: list[float]) -> list[int]: - return sorted(range(len(lst)), key=lambda i: -lst[i]) + if lower_bound is not None and upper_bound is not None and lower_bound >= upper_bound: + raise ValueError("lower_bound is greater or equal to upper_bound") if not hasattr(items, "__len__"): raise TypeError("items must be iterable") if len(items) == 0: return [] - if not isinstance(items, dict) and len(items) > 0 and hasattr(items[0], "__len__"): - if weight_pos is not None: - key = lambda x: x[weight_pos] # noqa: E731 - if key is None: - raise ValueError("Must provide weight_pos or key for tuple list") - - if not isinstance(items, dict) and key: - new_dict = dict(enumerate(items)) - items = {idx: key(val) for idx, val in enumerate(items)} - is_tuple_list = True + # Normalize input to (weight, payload) entries. The payload is whatever gets + # placed in the output bin: the dict key, the original tuple, or the weight. + is_dict = isinstance(items, dict) + if is_dict: + entries = [(weight, k) for k, weight in items.items()] # ty: ignore[unresolved-attribute] + # list of objects: dispatch on how to extract the weight from each item + elif weight_pos is not None: + # weight lives at a fixed tuple/list position; payload is the original item + entries = [(item[weight_pos], item) for item in items] # ty: ignore[not-subscriptable] + elif key is not None: + # custom extractor for arbitrary item types; payload is the original item + entries = [(key(item), item) for item in items] + elif isinstance(items[0], (tuple, list)): + # structured items but caller didn't say how to extract a weight + raise ValueError("Must provide weight_pos or key for tuple list") else: - is_tuple_list = False - - if isinstance(items, dict): - # get keys and values (weights) - keys = list(items) - vals = list(items.values()) - - # sort weights decreasingly - n_dcs = _rev_argsort_bins(vals) - - weights = _get_bins(vals, n_dcs) - keys = _get_bins(keys, n_dcs) - - bins = [[]] if is_tuple_list else [{}] - else: - weights = sorted(items, key=lambda x: -x) - bins = [[]] - - # find the valid indices - if lower_bound is not None and upper_bound is not None and lower_bound < upper_bound: - valid_ndcs = filter( - lambda i: lower_bound < weights[i] < upper_bound, range(len(weights)) - ) - elif lower_bound is not None: - valid_ndcs = filter(lambda i: lower_bound < weights[i], range(len(weights))) - elif upper_bound is not None: - valid_ndcs = filter(lambda i: weights[i] < upper_bound, range(len(weights))) - elif lower_bound is None and upper_bound is None: - valid_ndcs = range(len(weights)) - elif lower_bound >= upper_bound: - raise ValueError("lower_bound is greater or equal to upper_bound") - - valid_ndcs = list(valid_ndcs) - - weights = _get_bins(weights, valid_ndcs) - - if isinstance(items, dict): - keys = _get_bins(keys, valid_ndcs) - - # prepare array containing the current weight of the bins - weight_sum = [0.0] - - # iterate through the weight list, starting with heaviest - for item, weight in enumerate(weights): - if isinstance(items, dict): - item_key = keys[item] - - # find candidate bins where the weight might fit - candidate_bins = list( - filter(lambda i: weight_sum[i] + weight <= max_volume, range(len(weight_sum))) - ) - - if candidate_bins: # if there are candidates where it fits - # find the fullest bin where this item fits and assign it - candidate_index = _argmax_bins(_get_bins(weight_sum, candidate_bins)) - b = candidate_bins[candidate_index] - - # if this weight doesn't fit in any existent bin - elif item > 0: - # note! if this is the very first item then there is already an - # empty bin open so we don't need to open another one. - - # open a new bin - b = len(weight_sum) - weight_sum.append(0.0) - if isinstance(items, dict): - bins.append([] if is_tuple_list else {}) - else: - bins.append([]) - - # if we are at the very first item, use the empty bin already open - else: - b = 0 - - # put it in - if isinstance(items, dict): - bin_ = bins[b] - if is_tuple_list: - if not isinstance(bin_, list): - raise TypeError("bins contain lists when tuple-list mode is used") - bin_.append(item_key) - elif isinstance(bin_, dict): - bin_[item_key] = weight - else: - bin_ = bins[b] - if not isinstance(bin_, list): - raise TypeError("bins contain lists when items is not dict") - bin_.append(weight) - - # increase weight sum of the bin and continue with - # next item - weight_sum[b] += weight - - if not is_tuple_list: - return bins - return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in bins] + # plain numeric weights; the item is its own payload + entries = [(weight, weight) for weight in items] + + if lower_bound is not None: + entries = [e for e in entries if e[0] > lower_bound] + if upper_bound is not None: + entries = [e for e in entries if e[0] < upper_bound] + + # Pack heaviest first, opening a new bin only when nothing fits. + entries.sort(key=lambda e: -e[0]) + bin_entries: list[list[tuple[float, Any]]] = [] + bin_sums: list[float] = [] + for weight, payload in entries: + # get all bin indices that can fit this payload + candidate_bin_indices = [ + i for i, s in enumerate(bin_sums) if s + weight <= max_volume + ] + if candidate_bin_indices: + # get the idx of the most-full bin that has enough space + b = max(candidate_bin_indices, key=lambda i: bin_sums[i]) + else: # no bin has enough space. Create a new bin + b = len(bin_sums) + bin_entries.append([]) + bin_sums.append(0.0) + bin_entries[b].append((weight, payload)) + bin_sums[b] += weight + + if is_dict: + return [{payload: weight for weight, payload in bin_} for bin_ in bin_entries] + return [[payload for _, payload in bin_] for bin_ in bin_entries] def measure_model_memory_forward(state: SimState, model: ModelInterface) -> float: