Fix bias gradient computations #76
base: main
Conversation
bergson/gradients.py
Outdated
```python
        .unsqueeze(2)
        .expand(P.shape[0], -1, 1),
    ],
    [P, G.sum(dim=1).unsqueeze(2)],  # [N, S, O] -> [N, O]  # [N, O, 1]
```
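For reference, here is a small numpy sketch of the shape bookkeeping behind that concatenation. The shapes follow the comments in the diff (`P` is `[N, O, I]` per-example weight gradients, `G` is `[N, S, O]` per-example output gradients); the names and sizes are illustrative, not the project's code:

```python
import numpy as np

# Illustrative sizes: batch N, sequence S, output dim O, input dim I
N, S, O, I = 4, 7, 3, 5
rng = np.random.default_rng(0)
P = rng.normal(size=(N, O, I))  # per-example weight gradients [N, O, I]
G = rng.normal(size=(N, S, O))  # per-example output gradients [N, S, O]

# The per-example bias gradient is the output gradient summed over the
# sequence dimension: [N, S, O] -> [N, O] -> [N, O, 1].
bias_grad = G.sum(axis=1)[..., None]

# Concatenated as an extra "input" column: [N, O, I] ++ [N, O, 1] -> [N, O, I+1]
combined = np.concatenate([P, bias_grad], axis=2)
assert combined.shape == (N, O, I + 1)
```

This also makes the `[N, O, I+1] / [O, I]` shape mismatch visible: dividing `combined` by an `[O, I]` array of weight second moments fails to broadcast because of the extra bias column.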
Currently this just concatenates raw bias gradients to the normalized weights. Adam does have second moments for bias, but to use them we would need to expose them through the Normalizer. Also wasn't sure if there's a linear algebra trick I'm missing. @norabelrose
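To make the suggestion concrete, here is a minimal sketch of what "exposing second moments for bias" could look like: an Adam-style normalization `g / (sqrt(v) + eps)` applied to the bias column instead of concatenating it raw. The class name, fields, and `eps` are assumptions for illustration, not the actual `Normalizer` API:

```python
import numpy as np

class AdamBiasNormalizer:
    """Illustrative sketch: normalize per-example bias gradients by the
    optimizer's second moments, Adam-style."""

    def __init__(self, bias_avg_sq: np.ndarray, eps: float = 1e-8):
        self.bias_avg_sq = bias_avg_sq  # [O], second moments from optimizer state
        self.eps = eps

    def normalize(self, bias_grad: np.ndarray) -> np.ndarray:
        # bias_grad: [N, O] per-example bias gradients
        return bias_grad / (np.sqrt(self.bias_avg_sq) + self.eps)

v = np.array([4.0, 1.0, 0.25])          # hypothetical second moments
norm = AdamBiasNormalizer(v)
out = norm.normalize(np.array([[2.0, 1.0, 0.5]]))
# Each component is divided by sqrt(v), so this is roughly [1.0, 1.0, 1.0]
```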
This is fabulous, thank you!! 🙏 Interested to hear what Nora thinks, but I reckon exposing second moments for bias through the normalizer would be great.
```
# Conflicts:
#	pyproject.toml
```
Should add formatting on commit; let me know if that doesn't work for some reason.
Oh yeah, it was a problem with the ruff linter: it doesn't fix line-length errors (it leaves those to the formatter). Will add black back.
Couple of bug fixes to do with bias:

- Bias gradients are now summed over both the batch and sequence dimensions (`dim=(0,1)`), rather than just the sequence (`dim=1`).
- Fixed a shape mismatch (`[N, O, I+1] / [O, I]` division error).
- The biases are currently concatenated raw, as I wasn't sure of the best way to handle them. More in comment.

Update:

- Added a `bias_avg_sq` field to `AdafactorNormalizer` and `AdamNormalizer` to keep track of the bias second moments, so we can handle bias normalization separately from weight gradients.
- In `AdafactorNormalizer.normalize_()`, bias gradients are split off from `G` before weight processing.
- Modified `GradientCollectorCallback` (with help from Claude).
- Added a `scale_by_lr(lr)` method to `AdafactorNormalizer` (also fixes a bug where optimizer state tensors were being modified in-place); added `test_optimizer_state_extraction`.

Also added some unit tests. #75 should probably be merged before this.
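A minimal sketch of the in-place issue the `scale_by_lr` change addresses. The function names and the exact scaling factor here are illustrative, not the real API; the point is only that scaling must not mutate the optimizer's own state tensors:

```python
import numpy as np

def scale_by_lr_buggy(avg_sq: np.ndarray, lr: float) -> np.ndarray:
    avg_sq *= lr  # mutates the optimizer's state tensor in place!
    return avg_sq

def scale_by_lr_fixed(avg_sq: np.ndarray, lr: float) -> np.ndarray:
    return avg_sq * lr  # returns a new array; optimizer state untouched

state = np.ones(3)            # stand-in for an optimizer state tensor
scaled = scale_by_lr_fixed(state, lr=0.5)
assert np.array_equal(state, np.ones(3))  # optimizer state preserved
assert np.allclose(scaled, 0.5)
```

The buggy version would leave `state` permanently scaled, corrupting subsequent optimizer steps that read the same tensor.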
Someone better at linear algebra than me should probably have a look at this as well.