Skip to content

Commit 17f8865

Browse files
committed
validate batch size before collecting
1 parent 700639e commit 17f8865

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

bergson/collection.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ def callback(name: str, g: torch.Tensor, indices: list[int]):
7777
attention_cfgs=attention_cfgs,
7878
)
7979

80+
validate_batch_size(model, token_batch_size, collector)
81+
8082
# Allocate space ahead of time for the gradients
8183
grad_sizes = {name: math.prod(s) for name, s in collector.shapes().items()}
8284

@@ -252,3 +254,21 @@ def process_preconditioners(
252254
preconditioners_eigen[name] = (eigval, eigvec)
253255
if rank == 0:
254256
processor.preconditioners_eigen = preconditioners_eigen
257+
258+
259+
def validate_batch_size(
    model: PreTrainedModel,
    token_batch_size: int | None,
    collector: GradientCollector,
):
    """Validate that the specified token batch size fits on device.

    Runs one dummy forward/backward pass over a batch of
    ``token_batch_size`` random tokens with the gradient collector's
    hooks active, so an out-of-memory failure surfaces immediately
    rather than partway through a long collection run.

    Args:
        model: Model whose device/memory budget is being probed.
        token_batch_size: Number of tokens in the trial batch; ``None``
            means no fixed batch size was requested and the check is
            skipped.
        collector: Gradient collector entered for the trial pass so its
            memory overhead is included in the probe.

    Raises:
        Whatever the forward/backward pass raises when the batch does
        not fit (typically a CUDA out-of-memory error) — errors are
        deliberately not caught here.
    """
    # No fixed batch size requested -> nothing to validate.
    if token_batch_size is None:
        return

    # Arbitrary small token ids (0..9); only the tensor's shape and
    # device matter for the memory probe.
    random_tokens = torch.randint(
        0, 10, (1, token_batch_size), device=model.device, dtype=torch.long
    )
    # Enter the collector so its hooks contribute to peak memory,
    # matching the conditions of the real collection pass.
    with collector:
        # A scalar from a single logit is enough of a "loss" to drive
        # backward through the whole network.
        loss = model(random_tokens).logits[0, 0, 0].float()
        loss.backward()
    # Discard the probe gradients so they don't leak into real collection.
    # NOTE(review): diff indentation is ambiguous — confirm in the original
    # file whether zero_grad() sits inside or outside the collector context.
    model.zero_grad()

0 commit comments

Comments
 (0)