
Commit a550a65

fix: add normalizer bias support. fix trainer callback. add tests

1 parent 71ae7b2 · commit a550a65

File tree

3 files changed: +263 −29 lines changed

bergson/gradients.py

Lines changed: 42 additions & 10 deletions
@@ -68,14 +68,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
 class AdafactorNormalizer(Normalizer):
     """
     Row and column sums of second moments of gradients for a matrix-valued parameter.
+
+    Args:
+        row: Row statistics [O]
+        col: Column statistics [I]
+        bias_avg_sq: Optional second moments for bias [O]
     """

     row: Tensor  # shape [O]
     col: Tensor  # shape [I]
+    bias_avg_sq: Tensor | None = None  # shape [O]

     def __post_init__(self):
         assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
         assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
+        if self.bias_avg_sq is not None:
+            assert self.bias_avg_sq.ndim == 1, (
+                f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"
+            )

     @torch.compile
     def normalize_(
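As a usage note (not part of the commit), a minimal sketch of constructing the extended normalizer, assuming AdafactorNormalizer can be imported from bergson.gradients and using made-up dimensions O and I:

import torch

from bergson.gradients import AdafactorNormalizer  # assumed import path

O, I = 4, 3  # hypothetical output/input dimensions
norm = AdafactorNormalizer(
    row=torch.rand(O),          # row statistics, shape [O]
    col=torch.rand(I),          # column statistics, shape [I]
    bias_avg_sq=torch.rand(O),  # new optional field, shape [O]
)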
@@ -120,22 +130,29 @@ def to_adam(self) -> "AdamNormalizer":
         """
         Convert this Adafactor normalizer to an Adam normalizer by materializing the
         rank-one second moment matrix.
+
+        Preserves bias_avg_sq if present.
         """
         # Compute the second moment matrix as a square matrix of shape [O, I]
         # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
         # add it outside the square root. This could cause infs though if there are
         # any exactly zero rows or columns, so we should be careful.
         avg_sq = torch.outer(self.row, self.col) / self.row.mean()
-        return AdamNormalizer(avg_sq=avg_sq)
+        return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)


 @dataclass
 class AdamNormalizer(Normalizer):
     """
     Contains the second moments of the gradients.
+
+    Args:
+        avg_sq: Second moments for weights [O, I]
+        bias_avg_sq: Optional second moments for bias [O]
     """

     avg_sq: Tensor
+    bias_avg_sq: Tensor | None = None

     @torch.compile
     def normalize_(
@@ -153,16 +170,19 @@ def to_adafactor(self) -> AdafactorNormalizer:
         Convert this Adam normalizer to an Adafactor normalizer, minimizing the
         I-divergence (generalized Kullback-Leibler divergence) between the original
         and the factored second moments.
+
+        Preserves bias_avg_sq if present.
         """
         # We assume avg_sq is a square matrix of shape [O, I]
-        assert (
-            self.avg_sq.ndim == 2
-        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+        assert self.avg_sq.ndim == 2, (
+            f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+        )

         # Compute row and column means
         return AdafactorNormalizer(
             row=self.avg_sq.mean(dim=1),  # shape [O]
             col=self.avg_sq.mean(dim=0),  # shape [I]
+            bias_avg_sq=self.bias_avg_sq,  # Preserve bias second moments
         )
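To illustrate how the two normalizers now relate, a small sketch follows (illustrative only; shapes and values are made up, and the import path is assumed): to_adafactor() keeps the row and column means of avg_sq, to_adam() rebuilds a rank-one approximation via torch.outer(row, col) / row.mean(), and bias_avg_sq is carried through unchanged in both directions.

import torch

from bergson.gradients import AdamNormalizer  # assumed import path

O, I = 4, 3
avg_sq = torch.rand(O, I) + 0.1    # hypothetical Adam second moments [O, I]
bias_avg_sq = torch.rand(O) + 0.1  # hypothetical bias second moments [O]

adam = AdamNormalizer(avg_sq, bias_avg_sq)
fact = adam.to_adafactor()         # row/col means of avg_sq, bias preserved

# Rank-one reconstruction used by to_adam()
recon = torch.outer(fact.row, fact.col) / fact.row.mean()
assert recon.shape == avg_sq.shape
assert fact.to_adam().bias_avg_sq is bias_avg_sq  # bias moments pass through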
@@ -551,8 +571,22 @@ def _process_grad(self, module: nn.Module, _, grad_out):
         i = getattr(module, LayerAdapter.in_attr(module))
         o = getattr(module, LayerAdapter.out_attr(module))

-        # Pre-scale G by the Adafactor row statistics
+        # Handle bias gradients if needed (must be computed from raw G)
         norm = self.processor.normalizers.get(name)
+        bias_grad = None
+        if include_bias:
+            # Compute bias from raw G (before any normalization)
+            bias_grad = G.sum(dim=1)  # [N, S, O] -> [N, O]
+
+            # Normalize bias with appropriate second moments
+            if (
+                isinstance(norm, (AdamNormalizer, AdafactorNormalizer))
+                and hasattr(norm, "bias_avg_sq")
+                and norm.bias_avg_sq is not None
+            ):
+                bias_grad = bias_grad / norm.bias_avg_sq.sqrt().add_(1e-8)
+
+        # Pre-scale G by the Adafactor row statistics (for weight gradients)
         if isinstance(norm, AdafactorNormalizer):
             # Compare to the normalize_ method in AdafactorNormalizer
             r = norm.row.add(1e-30)
@@ -563,17 +597,15 @@ def _process_grad(self, module: nn.Module, _, grad_out):
         # If we are using AdamNormalizer, or including bias gradients
         # we need to materialize the full gradient and then project
         if isinstance(norm, AdamNormalizer) or include_bias:
-
             P = G.mT @ I  # [N, O, S] @ [N, S, I] → [N, O, I]
             if isinstance(norm, AdamNormalizer):
                 # Normalize the gradients using the second moment matrix
                 P /= norm.avg_sq.sqrt().add_(1e-8)

-            if include_bias:
-                # TODO: should we normalize the bias gradients?
-                # Append the raw bias gradient to the input
+            if include_bias and bias_grad is not None:
+                # Append pre-computed and normalized bias gradient
                 P = torch.cat(
-                    [P, G.sum(dim=1).unsqueeze(2)],  # [N, S, O] -> [N, O]  # [N, O, 1]
+                    [P, bias_grad.unsqueeze(2)],  # [N, O, 1]
                     dim=2,
                 )
             i += 1
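Taken on its own, the new bias path in _process_grad reduces to the following sketch, where N, S and O stand for a hypothetical batch size, sequence length and output dimension, and the 1e-8 epsilon matches the diff:

import torch

N, S, O = 2, 5, 4
G = torch.randn(N, S, O)           # per-token output gradients
bias_avg_sq = torch.rand(O) + 0.1  # bias second moments from the normalizer

# Per-example bias gradient: sum the output gradients over the sequence
bias_grad = G.sum(dim=1)           # [N, S, O] -> [N, O]

# Normalize by the root of the bias second moments (plus epsilon)
bias_grad = bias_grad / (bias_avg_sq.sqrt() + 1e-8)

# Appended to the projected weight gradient as one extra column
print(bias_grad.unsqueeze(2).shape)  # torch.Size([2, 4, 1])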

bergson/huggingface.py

Lines changed: 67 additions & 18 deletions
@@ -239,13 +239,12 @@ def on_step_end(
         **kwargs,
     ):
         self.on_substep_end(args, state, control)
-        print("Step end")

         # Record training order if enabled
         if self.order is not None:
-            assert (
-                self.batch_indices is not None
-            ), "Batch indices are not available for training order tracking"
+            assert self.batch_indices is not None, (
+                "Batch indices are not available for training order tracking"
+            )

             epoch = int(state.epoch or 0)
             global_step = state.global_step
@@ -279,32 +278,82 @@ def on_step_end(

         # Read normalizers off of the optimizer state. We need to figure out
         # what type of optimizer this is first.
+        # Collect references to both weight and bias second moments per layer
+        layer_second_moments: dict[str, dict[str, Tensor]] = {}
+
         for group in optimizer.param_groups:
-            lr_sqrt = group["lr"] ** 0.5
+            group_lr = group["lr"]

             for param in group["params"]:
-                name = param_to_name[param].removesuffix(".weight")
-                if name not in self.collector.target_info:
+                param_name = param_to_name[param]
+
+                # Extract layer name (remove .weight or .bias suffix)
+                if param_name.endswith(".weight"):
+                    param_type = "weight"
+                    layer_name = param_name.removesuffix(".weight")
+                elif param_name.endswith(".bias"):
+                    param_type = "bias"
+                    layer_name = param_name.removesuffix(".bias")
+                else:
+                    continue
+
+                if layer_name not in self.collector.target_info:
                     continue

                 p_state = optimizer.state[param]

+                # Initialize layer dict if needed, storing this group's learning rate
+                if layer_name not in layer_second_moments:
+                    layer_second_moments[layer_name] = {"lr": group_lr}
+
                 # Adam-like optimizer
                 if (eas := p_state.get("exp_avg_sq")) is not None:
-                    norm = AdamNormalizer(eas).to_adafactor()
-
+                    layer_second_moments[layer_name][param_type] = eas
                 # Adafactor-like optimizer
                 elif (vr := p_state.get("exp_avg_sq_row")) is not None:
                     vc = p_state.get("exp_avg_sq_col")
-                    norm = AdafactorNormalizer(vr, vc)
-                else:
-                    continue
-
-                # Scale the gradient by the current learning rate. It's factorized
-                # so we multiply each factor by the square root of the LR.
-                norm.row *= lr_sqrt
-                norm.col *= lr_sqrt
-                normalizers[name] = norm
+                    if param_type == "weight":
+                        # Factorized second moments for weights
+                        layer_second_moments[layer_name]["row"] = vr
+                        layer_second_moments[layer_name]["col"] = vc
+                    elif param_type == "bias":
+                        # Adafactor stores bias as regular exp_avg_sq
+                        bias_eas = p_state.get("exp_avg_sq")
+                        if bias_eas is not None:
+                            layer_second_moments[layer_name]["bias"] = bias_eas
+
+        # Build normalizers from collected second moments
+        for layer_name, moments in layer_second_moments.items():
+            lr_sqrt = moments["lr"] ** 0.5
+
+            # Adam-like: has weight exp_avg_sq
+            if "weight" in moments:
+                weight_eas = moments["weight"]
+                bias_eas = moments.get("bias")  # May be None
+
+                # Create Adam normalizer with optional bias, then convert to Adafactor
+                # TODO: always convert to adafactor?
+                norm = AdamNormalizer(weight_eas, bias_eas).to_adafactor()
+
+                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
+                norm.row = norm.row * lr_sqrt
+                norm.col = norm.col * lr_sqrt
+                if norm.bias_avg_sq is not None:
+                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+
+            # Adafactor-like: has row/col
+            elif "row" in moments and "col" in moments:
+                bias_eas = moments.get("bias")  # May be present
+                norm = AdafactorNormalizer(moments["row"], moments["col"], bias_eas)
+                # Scale by LR (factorized) - use non-in-place ops to avoid modifying optimizer state
+                norm.row = norm.row * lr_sqrt
+                norm.col = norm.col * lr_sqrt
+                if norm.bias_avg_sq is not None:
+                    norm.bias_avg_sq = norm.bias_avg_sq * (lr_sqrt**2)
+            else:
+                continue
+
+            normalizers[layer_name] = norm

         proc.normalizers = normalizers
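For context, the state keys read by the callback are the ones PyTorch's Adam-family optimizers populate after a step. Below is a minimal sketch with torch.optim.AdamW, showing exp_avg_sq for both the weight and the bias of a toy layer (the Adafactor case is analogous, with exp_avg_sq_row / exp_avg_sq_col for matrix parameters, as the comments above note); the layer and shapes here are illustrative:

import torch
from torch import nn

layer = nn.Linear(3, 4)  # toy layer with weight [4, 3] and bias [4]
opt = torch.optim.AdamW(layer.parameters(), lr=1e-3)

layer(torch.randn(2, 3)).sum().backward()
opt.step()

# Map each parameter back to its name, as the callback does with param_to_name
param_to_name = {p: n for n, p in layer.named_parameters()}
for p, name in param_to_name.items():
    eas = opt.state[p].get("exp_avg_sq")
    print(name, tuple(eas.shape))
# weight (4, 3)
# bias (4,)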