Skip to content

Commit a0e88c3

Browse files
committed
Fix ASRModelAdapter state passing to F() and parameter count
- Fixed state passing by extracting encoder_adapter and decoder_adapter from the full state dict in adapt_encoder_features and adapt_decoder_features - Fixed expected parameter count from 33664 to 33600 in test_parameter_counts
1 parent b16c188 commit a0e88c3

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

axlearn/audio/adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def adapt_encoder_features(self, features, *, is_training=False, prng_key=None,
218218
inputs=(features,),
219219
is_training=is_training,
220220
prng_key=prng_key,
221-
state=state,
221+
state=state["encoder_adapter"],
222222
)
223223
return outputs
224224

@@ -248,7 +248,7 @@ def adapt_decoder_features(self, features, *, is_training=False, prng_key=None,
248248
inputs=(features,),
249249
is_training=is_training,
250250
prng_key=prng_key,
251-
state=state,
251+
state=state["decoder_adapter"],
252252
)
253253
return outputs
254254

axlearn/audio/adapter_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def test_parameter_counts(self):
191191
total_params += np.prod(layer_norm_scale.shape)
192192
total_params += np.prod(layer_norm_bias.shape)
193193

194-
self.assertEqual(total_params, 33664)
194+
self.assertEqual(total_params, 33600)
195195

196196
@parameterized.parameters([True, False])
197197
def test_training_vs_eval_mode(self, is_training: bool):

0 commit comments

Comments
 (0)