Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 51 additions & 119 deletions torch_sim/autobatching.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
logger = logging.getLogger(__name__)


def to_constant_volume_bins( # noqa: C901, PLR0915
def to_constant_volume_bins( # noqa: C901
items: dict[int, float] | list[Any],
max_volume: float,
*,
Expand All @@ -57,8 +57,9 @@ def to_constant_volume_bins( # noqa: C901, PLR0915
Args:
items (dict[int, float] | list[float] | list[tuple]): Items to distribute,
provided as either:
- Dictionary with numeric weights as values
- List of numeric weights
- Dictionary with numeric weights as values: maps system_idx -> weight
- List of numeric weights. Useful when you care about how many bins
are needed since the index isn't tracked.
- List of tuples containing weights (requires weight_pos or key)
max_volume (float): Maximum allowed weight sum per bin.
weight_pos (int | None): For tuple lists, index of weight in each tuple.
Expand All @@ -78,132 +79,63 @@ def to_constant_volume_bins( # noqa: C901, PLR0915
- List of lists of tuples if input was a list of tuples

Raises:
TypeError: If input is not iterable.
ValueError: If weight_pos or key is not provided for tuple list input,
or if lower_bound >= upper_bound.
"""

def _get_bins[T](lst: list[T], ndx: list[int]) -> list[T]:
return [lst[n] for n in ndx]

def _argmax_bins(lst: list[float]) -> int:
return max(range(len(lst)), key=lambda idx: lst[idx])

def _rev_argsort_bins(lst: list[float]) -> list[int]:
return sorted(range(len(lst)), key=lambda i: -lst[i])
if lower_bound is not None and upper_bound is not None and lower_bound >= upper_bound:
raise ValueError("lower_bound is greater or equal to upper_bound")

if not hasattr(items, "__len__"):
raise TypeError("items must be iterable")
if len(items) == 0:
return []

if not isinstance(items, dict) and len(items) > 0 and hasattr(items[0], "__len__"):
if weight_pos is not None:
key = lambda x: x[weight_pos] # noqa: E731
if key is None:
raise ValueError("Must provide weight_pos or key for tuple list")

if not isinstance(items, dict) and key:
new_dict = dict(enumerate(items))
items = {idx: key(val) for idx, val in enumerate(items)}
is_tuple_list = True
# Normalize input to (weight, payload) entries. The payload is whatever gets
# placed in the output bin: the dict key, the original tuple, or the weight.
is_dict = isinstance(items, dict)
if is_dict:
entries = [(weight, k) for k, weight in items.items()] # ty: ignore[unresolved-attribute]
# list of objects: dispatch on how to extract the weight from each item
elif weight_pos is not None:
# weight lives at a fixed tuple/list position; payload is the original item
entries = [(item[weight_pos], item) for item in items] # ty: ignore[not-subscriptable]
elif key is not None:
# custom extractor for arbitrary item types; payload is the original item
entries = [(key(item), item) for item in items]
elif isinstance(items[0], (tuple, list)):
# structured items but caller didn't say how to extract a weight
raise ValueError("Must provide weight_pos or key for tuple list")
else:
is_tuple_list = False

if isinstance(items, dict):
# get keys and values (weights)
keys = list(items)
vals = list(items.values())

# sort weights decreasingly
n_dcs = _rev_argsort_bins(vals)

weights = _get_bins(vals, n_dcs)
keys = _get_bins(keys, n_dcs)

bins = [[]] if is_tuple_list else [{}]
else:
weights = sorted(items, key=lambda x: -x)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All of the above logic for building the descending-sorted weights list is simplified to `entries.sort(key=lambda e: -e[0])` below.

bins = [[]]

# find the valid indices
if lower_bound is not None and upper_bound is not None and lower_bound < upper_bound:
valid_ndcs = filter(
lambda i: lower_bound < weights[i] < upper_bound, range(len(weights))
)
elif lower_bound is not None:
valid_ndcs = filter(lambda i: lower_bound < weights[i], range(len(weights)))
elif upper_bound is not None:
valid_ndcs = filter(lambda i: weights[i] < upper_bound, range(len(weights)))
elif lower_bound is None and upper_bound is None:
valid_ndcs = range(len(weights))
elif lower_bound >= upper_bound:
raise ValueError("lower_bound is greater or equal to upper_bound")

valid_ndcs = list(valid_ndcs)

weights = _get_bins(weights, valid_ndcs)
Copy link
Copy Markdown
Collaborator Author

@curtischong curtischong May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above logic to get all of the valid indices is simplified to:

    if lower_bound is not None:
        entries = [e for e in entries if e[0] > lower_bound]
    if upper_bound is not None:
        entries = [e for e in entries if e[0] < upper_bound]


if isinstance(items, dict):
keys = _get_bins(keys, valid_ndcs)
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer needed since we now operate on the (weight, payload) tuples in the entries list, which automatically carry the keys along with the weights.


# prepare array containing the current weight of the bins
weight_sum = [0.0]

# iterate through the weight list, starting with heaviest
for item, weight in enumerate(weights):
if isinstance(items, dict):
item_key = keys[item]

# find candidate bins where the weight might fit
candidate_bins = list(
filter(lambda i: weight_sum[i] + weight <= max_volume, range(len(weight_sum)))
)

if candidate_bins: # if there are candidates where it fits
# find the fullest bin where this item fits and assign it
candidate_index = _argmax_bins(_get_bins(weight_sum, candidate_bins))
b = candidate_bins[candidate_index]

# if this weight doesn't fit in any existent bin
elif item > 0:
# note! if this is the very first item then there is already an
# empty bin open so we don't need to open another one.

# open a new bin
b = len(weight_sum)
weight_sum.append(0.0)
if isinstance(items, dict):
bins.append([] if is_tuple_list else {})
else:
bins.append([])

# if we are at the very first item, use the empty bin already open
else:
b = 0

# put it in
if isinstance(items, dict):
bin_ = bins[b]
if is_tuple_list:
if not isinstance(bin_, list):
raise TypeError("bins contain lists when tuple-list mode is used")
bin_.append(item_key)
elif isinstance(bin_, dict):
bin_[item_key] = weight
else:
bin_ = bins[b]
if not isinstance(bin_, list):
raise TypeError("bins contain lists when items is not dict")
bin_.append(weight)

# increase weight sum of the bin and continue with
# next item
weight_sum[b] += weight

if not is_tuple_list:
return bins
return [[new_dict[item_key] for item_key in bin_keys] for bin_keys in bins]
# plain numeric weights; the item is its own payload
entries = [(weight, weight) for weight in items]

if lower_bound is not None:
entries = [e for e in entries if e[0] > lower_bound]
if upper_bound is not None:
entries = [e for e in entries if e[0] < upper_bound]

# Pack heaviest first, opening a new bin only when nothing fits.
entries.sort(key=lambda e: -e[0])
bin_entries: list[list[tuple[float, Any]]] = []
bin_sums: list[float] = []
for weight, payload in entries:
# get all bin indices that can fit this payload
candidate_bin_indices = [
i for i, s in enumerate(bin_sums) if s + weight <= max_volume
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing I noticed was missing in the original implementation (I didn't add it here either): I think we need an assert to ensure that all weights are <= max_volume? If the input contains an item that exceeds max_volume, we can never fit it into a bin without overflowing, not even a new, empty bin.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we add this snippet around line 120?

    # warn on oversized items — they'll still be placed in a bin by themselves,
    # but that bin will exceed max_volume
    for weight, payload in entries:
        if weight > max_volume:
            logger.warning(
                "item %r has weight %s > max_volume %s; placing in its own bin anyway",
                payload,
                weight,
                max_volume,
            )

Right now, we allow items to overflow the max bin size. Some tests fail if we add this check and throw an exception.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually wanted this to be an exception, but I can see cases where people might want the weight to exceed the max volume so maybe it's not necessary to add this in.

]
if candidate_bin_indices:
# get the idx of the most-full bin that has enough space
b = max(candidate_bin_indices, key=lambda i: bin_sums[i])
else: # no bin has enough space. Create a new bin
b = len(bin_sums)
bin_entries.append([])
bin_sums.append(0.0)
bin_entries[b].append((weight, payload))
bin_sums[b] += weight

if is_dict:
return [{payload: weight for weight, payload in bin_} for bin_ in bin_entries]
return [[payload for _, payload in bin_] for bin_ in bin_entries]


def measure_model_memory_forward(state: SimState, model: ModelInterface) -> float:
Expand Down
Loading