@@ -68,14 +68,24 @@ def state_dict(self) -> dict[str, str | Tensor]:
 class AdafactorNormalizer(Normalizer):
     """
     Row and column sums of second moments of gradients for a matrix-valued parameter.
+
+    Args:
+        row: Row statistics [O]
+        col: Column statistics [I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     row: Tensor  # shape [O]
     col: Tensor  # shape [I]
+    bias_avg_sq: Tensor | None = None  # shape [O]
 
     def __post_init__(self):
         assert self.row.ndim == 1, f"Expected 1D tensor for row, got {self.row.ndim}D"
         assert self.col.ndim == 1, f"Expected 1D tensor for col, got {self.col.ndim}D"
+        if self.bias_avg_sq is not None:
+            assert self.bias_avg_sq.ndim == 1, (
+                f"Expected 1D tensor for bias_avg_sq, got {self.bias_avg_sq.ndim}D"
+            )
 
     @torch.compile
     def normalize_(
@@ -120,22 +130,29 @@ def to_adam(self) -> "AdamNormalizer":
         """
         Convert this Adafactor normalizer to an Adam normalizer by materializing the
         rank-one second moment matrix.
+
+        Preserves bias_avg_sq if present.
         """
         # Compute the second moment matrix as a full matrix of shape [O, I]
         # NOTE: We don't add the epsilon here, since the AdamNormalizer is going to
         # add it outside the square root. This could cause infs though if there are
         # any exactly zero rows or columns, so we should be careful.
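         # The factored estimate is avg_sq[i, j] ≈ row[i] * col[j] / mean(row): row and col
         # hold per-row and per-column means, so mean(row) is the grand mean of the original
         # matrix, and this is Adafactor's usual rank-one reconstruction of the second moments.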
         avg_sq = torch.outer(self.row, self.col) / self.row.mean()
-        return AdamNormalizer(avg_sq=avg_sq)
+        return AdamNormalizer(avg_sq=avg_sq, bias_avg_sq=self.bias_avg_sq)
 
 
 @dataclass
 class AdamNormalizer(Normalizer):
     """
     Contains the second moments of the gradients.
+
+    Args:
+        avg_sq: Second moments for weights [O, I]
+        bias_avg_sq: Optional second moments for bias [O]
     """
 
     avg_sq: Tensor
+    bias_avg_sq: Tensor | None = None
 
     @torch.compile
     def normalize_(
@@ -153,16 +170,19 @@ def to_adafactor(self) -> AdafactorNormalizer:
         Convert this Adam normalizer to an Adafactor normalizer, minimizing the
         I-divergence (generalized Kullback-Leibler divergence) between the original
         and the factored second moments.
+
+        Preserves bias_avg_sq if present.
         """
         # We assume avg_sq is a full matrix of shape [O, I]
-        assert (
-            self.avg_sq.ndim == 2
-        ), f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+        assert self.avg_sq.ndim == 2, (
+            f"Expected 2D tensor for avg_sq, got {self.avg_sq.ndim}D"
+        )
 
         # Compute row and column means
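         # For a nonnegative matrix, the row and column means (rescaled by the grand mean
         # at reconstruction time in to_adam) give the rank-one factorization that minimizes
         # the I-divergence to the original matrix, which is the criterion named above.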
         return AdafactorNormalizer(
             row=self.avg_sq.mean(dim=1),  # shape [O]
             col=self.avg_sq.mean(dim=0),  # shape [I]
+            bias_avg_sq=self.bias_avg_sq,  # Preserve bias second moments
         )
 
 
@@ -551,8 +571,22 @@ def _process_grad(self, module: nn.Module, _, grad_out):
         i = getattr(module, LayerAdapter.in_attr(module))
         o = getattr(module, LayerAdapter.out_attr(module))
 
-        # Pre-scale G by the Adafactor row statistics
+        # Handle bias gradients if needed (must be computed from raw G)
         norm = self.processor.normalizers.get(name)
+        bias_grad = None
+        if include_bias:
+            # Compute bias from raw G (before any normalization)
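+            # (summing the output gradient over the sequence dimension yields the bias
+            # gradient, since the bias contributes additively at every position)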
+            bias_grad = G.sum(dim=1)  # [N, S, O] -> [N, O]
+
+            # Normalize bias with appropriate second moments
+            if (
+                isinstance(norm, (AdamNormalizer, AdafactorNormalizer))
+                and hasattr(norm, "bias_avg_sq")
+                and norm.bias_avg_sq is not None
+            ):
+                bias_grad = bias_grad / norm.bias_avg_sq.sqrt().add_(1e-8)
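+                # As with the weight second moments, epsilon is added outside the
+                # square root rather than inside it.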
+
+        # Pre-scale G by the Adafactor row statistics (for weight gradients)
         if isinstance(norm, AdafactorNormalizer):
             # Compare to the normalize_ method in AdafactorNormalizer
             r = norm.row.add(1e-30)
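             # (the tiny 1e-30 offset guards against exactly-zero row statistics)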
@@ -563,17 +597,15 @@ def _process_grad(self, module: nn.Module, _, grad_out):
         # If we are using AdamNormalizer, or including bias gradients
         # we need to materialize the full gradient and then project
         if isinstance(norm, AdamNormalizer) or include_bias:
-
             P = G.mT @ I  # [N, O, S] @ [N, S, I] → [N, O, I]
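             # Each per-sample weight gradient is the sum over the sequence of outer
             # products between output gradients and inputs, materialized here as [N, O, I].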
             if isinstance(norm, AdamNormalizer):
                 # Normalize the gradients using the second moment matrix
                 P /= norm.avg_sq.sqrt().add_(1e-8)
 
-            if include_bias:
-                # TODO: should we normalize the bias gradients?
-                # Append the raw bias gradient to the input
+            if include_bias and bias_grad is not None:
+                # Append pre-computed and normalized bias gradient
                 P = torch.cat(
-                    [P, G.sum(dim=1).unsqueeze(2)],  # [N, S, O] -> [N, O]  # [N, O, 1]
+                    [P, bias_grad.unsqueeze(2)],  # [N, O, 1]
                     dim=2,
                 )
                 i += 1