Skip to content

Commit 209ca74

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 25653cd commit 209ca74

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

bergson/collector/collector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,9 @@ def setup(self) -> None:
333333
"""Initialize gradient storage dictionary."""
334334
self.mod_grads = {}
335335

336-
assert isinstance(self.model.device, torch.device), (
337-
"Model device is not set correctly"
338-
)
336+
assert isinstance(
337+
self.model.device, torch.device
338+
), "Model device is not set correctly"
339339

340340
self.save_dtype = (
341341
torch.float32 if self.model.dtype == torch.float32 else torch.float16

bergson/gradients.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,9 +512,8 @@ def _process_grad(self, module: nn.Module, _, grad_out):
512512
G = grad_out[0] # [N, S, O]
513513
I = module._inputs # [N, S, I/q]
514514

515-
516515
name = assert_type(str, module._name)
517-
if name== "h.1.mlp.c_fc":
516+
if name == "h.1.mlp.c_fc":
518517
debugpy.breakpoint()
519518

520519
# different way of checking for bias as above

0 commit comments

Comments
 (0)