@baberabb commented Nov 14, 2025

A couple of bug fixes related to bias:

  • The gradients were contracted over both the batch and sequence dimension (dim=(0,1)), rather than just the sequence (dim=1).
  • Normalize weights with Adam before concatenating the bias, to avoid a shape mismatch (a [N, O, I+1] / [O, I] division error). The biases are currently concatenated raw, as I wasn't sure of the best way to handle them. More in the comment below.
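To make the first bug concrete, here is a standalone shape sketch (the [N, S, O] shapes are taken from the comments in the diff; this is not the project's actual code):

```python
import torch

# Assumed shapes: N = batch, S = sequence, O = output features.
N, S, O = 4, 16, 8
G = torch.randn(N, S, O)  # per-token bias gradients

# Bug: contracting over batch AND sequence collapses the per-sample axis.
wrong = G.sum(dim=(0, 1))  # shape [O]: one gradient shared across the batch
# Fix: contract over the sequence dimension only.
right = G.sum(dim=1)       # shape [N, O]: one bias gradient per sample

print(wrong.shape, right.shape)  # torch.Size([8]) torch.Size([4, 8])
```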

update:

  • Added bias_avg_sq field to AdafactorNormalizer and AdamNormalizer to keep track of the bias second moments so we can handle bias normalization separately from weight gradients in AdafactorNormalizer.normalize_():
    • Normalize bias from raw gradient G before weight processing
    • Sum bias gradients over sequence dimension
    • Append normalized bias as extra column when include_bias=True
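A sketch of those three steps in one helper (not the project's actual code; the function name, signature, and eps are assumptions, while bias_avg_sq and include_bias come from the description above):

```python
import torch

def append_normalized_bias(P_norm, G, bias_avg_sq, include_bias=True, eps=1e-8):
    # P_norm:      normalized weight gradients, shape [N, O, I]
    # G:           raw output gradients, shape [N, S, O]
    # bias_avg_sq: second-moment estimate for the bias, shape [O]
    if not include_bias:
        return P_norm
    # Sum bias gradients over the sequence dimension: [N, S, O] -> [N, O]
    bias_grad = G.sum(dim=1)
    # Normalize the bias from the raw gradient, separately from the weights
    bias_norm = bias_grad / (bias_avg_sq + eps).sqrt()
    # Append the normalized bias as an extra column: [N, O, I] -> [N, O, I + 1]
    return torch.cat([P_norm, bias_norm.unsqueeze(2)], dim=2)
```

With P_norm of shape [N, O, I] this returns [N, O, I + 1], which is exactly the shape the earlier [N, O, I+1] / [O, I] mismatch was about.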

Modified GradientCollectorCallback (with help from Claude):

  • Extract bias second moments from both the Adam and Adafactor optimizers
  • Added a scale_by_lr(lr) method to AdafactorNormalizer (also fixes a bug where optimizer state tensors were being modified in place)
  • Added a test_optimizer_state_extraction test
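The in-place bug is easy to reintroduce, so here is a minimal sketch of the out-of-place pattern (a simplified stand-in class for illustration, not the real AdafactorNormalizer):

```python
import torch

class AdafactorNormalizer:  # simplified stand-in for illustration
    def __init__(self, row_avg_sq, col_avg_sq):
        self.row_avg_sq = row_avg_sq
        self.col_avg_sq = col_avg_sq

    def scale_by_lr(self, lr):
        # Out-of-place multiply returns new tensors, so the optimizer's
        # state (which these tensors may alias) is never mutated.
        return AdafactorNormalizer(self.row_avg_sq * lr, self.col_avg_sq * lr)

state_row = torch.ones(3)  # pretend this aliases optimizer state
norm = AdafactorNormalizer(state_row, torch.ones(5))
scaled = norm.scale_by_lr(0.1)
assert state_row.allclose(torch.ones(3))  # original state untouched
```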

Also added some unit tests. #75 should probably be merged before this.

Someone better at linear algebra than me should probably have a look at this as well.

Diff context for the review comment below:

    .unsqueeze(2)
    .expand(P.shape[0], -1, 1),
    ],
    [P, G.sum(dim=1).unsqueeze(2)],  # G: [N, S, O] -> sum(dim=1): [N, O] -> unsqueeze(2): [N, O, 1]
@baberabb (Author) commented Nov 14, 2025

Currently this just concatenates the raw bias gradients to the normalized weights. Adam does have second moments for the bias, but to use them we would need to expose them through the Normalizer. Also wasn't sure if there's a linear algebra trick I'm missing. @norabelrose
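For reference, Adam's second moments are reachable via the "exp_avg_sq" key of the PyTorch optimizer state (that key is real; everything else here is an assumed sketch of what using it for the bias could look like):

```python
import torch

model = torch.nn.Linear(4, 3)  # weight [3, 4], bias [3]
opt = torch.optim.Adam(model.parameters())

model(torch.randn(2, 4)).sum().backward()
opt.step()  # populates optimizer state, including second moments

# Adam stores a per-parameter second-moment estimate under "exp_avg_sq".
bias_avg_sq = opt.state[model.bias]["exp_avg_sq"]  # shape [3]

# Normalize the raw bias gradient instead of concatenating it raw:
eps = 1e-8
bias_norm = model.bias.grad / (bias_avg_sq + eps).sqrt()
```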

@luciaquirke (Collaborator)

This is fabulous, thank you!! 🙏 Interested to hear what Nora thinks but I reckon exposing second moments for bias through the normalizer would be great

Merge commit message:

    # Conflicts:
    #	pyproject.toml
@luciaquirke (Collaborator)

Running

pip install -e ".[dev]"
pre-commit install

Should add formatting on commit, let me know if that doesn't work for some reason

@baberabb
Copy link
Author

> Running
>
> pip install -e ".[dev]"
> pre-commit install
>
> Should add formatting on commit, let me know if that doesn't work for some reason

Oh yeah, it was a problem with the ruff linter: it doesn't fix line-length errors (it leaves those to the formatter). Will add black back.
