Conversation

@LouisYRYJ
Contributor

This PR introduces the new Collector class that will make it easy to implement custom hooks (e.g. for EK-FAC) and generally allow for more flexible custom methods.

batches: list[list[int]] | None = None,
target_modules: set[str] | None = None,
attention_cfgs: dict[str, AttentionConfig] | None = None,
attention_cfgs: dict[str, AttentionConfig] = {},
Collaborator

@luciaquirke Nov 25, 2025

I think it's best not to use mutable objects as default values, because they are only instantiated once and so can be shared across calls of the function.
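
A minimal sketch of the sharing behaviour (the function and names here are hypothetical, not the PR's code):

def add_cfg(name, cfgs={}):
    cfgs[name] = True  # mutates the single dict created at definition time
    return cfgs

add_cfg("a")  # {'a': True}
add_cfg("b")  # {'a': True, 'b': True} <- 'a' leaked from the previous call

def add_cfg_fixed(name, cfgs=None):
    cfgs = {} if cfgs is None else cfgs  # fresh dict per call
    cfgs[name] = True
    return cfgs

This is why the | None = None pattern in the signature above is the usual fix.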

Contributor Author

Good point, will change!

cfg=cfg,
)

computer._compute(desc="New worker - Collecting gradients")
Collaborator

@luciaquirke Nov 25, 2025

No underscore in compute, since it's a public fn. I feel like this class/function could also have a more concrete name; I'm not sure what though, but in a sense every function is a computer.
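
A sketch of the convention being suggested (the class and method names here are hypothetical, not the PR's final choice):

class GradientCollector:
    def compute(self, desc: str) -> None:
        """Public entry point: no leading underscore."""
        self._run_batches(desc)

    def _run_batches(self, desc: str) -> None:
        """Leading underscore marks this as an internal helper."""
        ...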

Abstract base class for collectors that attach forward and backward hooks to model
layers.
Automatically discovers nn.Linear layers in the model, registers hooks during
Collaborator

nn.Linear layers -> supported modules

implement custom logic.
Assumes model input shape is [N, S, I] where N=batch size, S=sequence length,
I=input dimension.
Collaborator

model -> module?
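
For reference, the shape assumption in concrete terms (toy dimensions):

import torch
from torch import nn

layer = nn.Linear(16, 32)    # I=16 input features, 32 output features
x = torch.randn(4, 128, 16)  # [N, S, I]: batch 4, sequence 128
y = layer(x)                 # nn.Linear maps only the last dimension
print(y.shape)               # torch.Size([4, 128, 32])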


target_modules: set[str] | None = None
"""
Set of module names to attach hooks to. Should consist only of nn.Linear modules.
Collaborator

Fix


dtype: torch.dtype

def setup(self) -> None:
Collaborator

Is this removable btw?
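
Putting the quoted pieces together, a minimal sketch of what the Collector ABC might look like, assuming standard PyTorch hooks; the on_forward/on_backward names, the teardown helper, and the dtype default are assumptions, not the PR's API:

import abc

import torch
from torch import nn

class Collector(abc.ABC):
    """Attaches forward and backward hooks to supported modules (here nn.Linear)."""

    def __init__(
        self,
        model: nn.Module,
        target_modules: set[str] | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.model = model
        self.dtype = dtype
        # Discover supported modules unless an explicit set of names is given.
        self.target_modules = target_modules if target_modules is not None else {
            name for name, m in model.named_modules() if isinstance(m, nn.Linear)
        }
        self._handles = []

    def setup(self) -> None:
        """Register forward and backward hooks on each target module."""
        modules = dict(self.model.named_modules())
        for name in self.target_modules:
            mod = modules[name]
            self._handles.append(mod.register_forward_hook(
                lambda m, inp, out, name=name: self.on_forward(name, inp[0], out)
            ))
            self._handles.append(mod.register_full_backward_hook(
                lambda m, gin, gout, name=name: self.on_backward(name, gout[0])
            ))

    def teardown(self) -> None:
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    @abc.abstractmethod
    def on_forward(self, name: str, inputs: torch.Tensor, output: torch.Tensor) -> None:
        """inputs has shape [N, S, I]: batch size, sequence length, input dim."""

    @abc.abstractmethod
    def on_backward(self, name: str, grad_output: torch.Tensor) -> None:
        """grad_output has shape [N, S, O]: batch size, sequence length, output dim."""

The name=name default argument pins each module's name at registration time, so the hooks don't all close over the final loop variable.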
