Collector refactor #81
Conversation
```diff
     batches: list[list[int]] | None = None,
     target_modules: set[str] | None = None,
-    attention_cfgs: dict[str, AttentionConfig] | None = None,
+    attention_cfgs: dict[str, AttentionConfig] = {},
```
I think it's best not to use mutable objects as default values, because they are only instantiated once and so could be shared across calls to the function.
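A quick illustration of the pitfall (hypothetical function and argument names, not code from this PR): the default value is evaluated once, at function definition time, so every call that omits the argument mutates the same shared dict.

```python
def register_cfg_bad(name: str, cfgs: dict = {}):  # default {} is created once, at def time
    cfgs[name] = "attn"
    return cfgs

a = register_cfg_bad("q_proj")
b = register_cfg_bad("k_proj")
assert a is b                           # both calls returned the same shared dict
assert list(b) == ["q_proj", "k_proj"]  # state leaked between calls

def register_cfg_good(name: str, cfgs: dict | None = None):
    if cfgs is None:  # None sentinel: build a fresh dict per call
        cfgs = {}
    cfgs[name] = "attn"
    return cfgs
```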
Good point, will change!
```python
    cfg=cfg,
)

computer._compute(desc="New worker - Collecting gradients")
```
No underscore in compute, since it's a public fn. I also feel like this class/function could have a more concrete name, though I'm not sure what; in a sense, every function is a computer.
```python
Abstract base class for collectors that attach forward and backward hooks to model
layers.

Automatically discovers nn.Linear layers in the model, registers hooks during
```
nn.Linear layers -> supported modules
```python
implement custom logic.

Assumes model input shape is [N, S, I] where N=batch size, S=sequence length,
I=input dimension.
```
model -> module?
```python
target_modules: set[str] | None = None
"""
Set of module names to attach hooks to. Should consist only of nn.Linear modules.
```
Fix
```python
dtype: torch.dtype

def setup(self) -> None:
```
Is this removable, btw?
This PR introduces the new Collector class, which will make it easy to implement custom hooks (e.g. for EK-FAC) and generally allows for more flexible custom methods.
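As a rough sketch of the extension point being described (hypothetical class, method, and attribute names, not the PR's actual API; only standard torch hook calls like register_forward_hook are assumed), a custom collector overrides the hook logic while the shared machinery handles module discovery and registration:

```python
import torch
import torch.nn as nn

# Hypothetical sketch: a collector that records per-example activation norms.
class NormCollector:
    def __init__(self, model: nn.Module):
        self.model = model
        self.norms: dict[str, torch.Tensor] = {}
        self.handles: list[torch.utils.hooks.RemovableHandle] = []

    def setup(self) -> None:
        # Discover supported modules and register forward hooks on them.
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Linear):
                self.handles.append(module.register_forward_hook(self._make_hook(name)))

    def _make_hook(self, name: str):
        def hook(module, inputs, output):
            # Input assumed to be [N, S, I]; record one norm per batch element.
            self.norms[name] = inputs[0].detach().norm(dim=(1, 2))
        return hook

    def teardown(self) -> None:
        for handle in self.handles:
            handle.remove()
```

An EK-FAC collector would presumably follow the same pattern, accumulating input and gradient covariances inside the forward and backward hooks instead of norms.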