Skip to content

Commit 76b15fe

Browse files
committed
Remove fallback direct calls from ASRModelAdapter
- Made prng_key and state required parameters in adapt_encoder_features and adapt_decoder_features - Removed fallback direct module calls which don't work outside invocation context - Updated test_direct_call_fallback to pass required prng_key and state parameters
1 parent a0e88c3 commit 76b15fe

File tree

2 files changed

+28
-34
lines changed

2 files changed

+28
-34
lines changed

axlearn/audio/adapter.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
195195
),
196196
)
197197

198-
def adapt_encoder_features(self, features, *, is_training=False, prng_key=None, state=None):
198+
def adapt_encoder_features(self, features, *, is_training=False, prng_key, state):
199199
"""Apply adaptation to encoder features.
200200
201201
Args:
@@ -211,21 +211,16 @@ def adapt_encoder_features(self, features, *, is_training=False, prng_key=None,
211211
if not cfg.adapt_encoder:
212212
return features
213213

214-
# Use functional API if state and prng_key are provided
215-
if state is not None and prng_key is not None:
216-
outputs, _ = F(
217-
self.encoder_adapter,
218-
inputs=(features,),
219-
is_training=is_training,
220-
prng_key=prng_key,
221-
state=state["encoder_adapter"],
222-
)
223-
return outputs
224-
225-
# Fall back to direct call if no state/prng_key
226-
return self.encoder_adapter(features)
214+
outputs, _ = F(
215+
self.encoder_adapter,
216+
inputs=(features,),
217+
is_training=is_training,
218+
prng_key=prng_key,
219+
state=state["encoder_adapter"],
220+
)
221+
return outputs
227222

228-
def adapt_decoder_features(self, features, *, is_training=False, prng_key=None, state=None):
223+
def adapt_decoder_features(self, features, *, is_training=False, prng_key, state):
229224
"""Apply adaptation to decoder features.
230225
231226
Args:
@@ -241,16 +236,11 @@ def adapt_decoder_features(self, features, *, is_training=False, prng_key=None,
241236
if not cfg.adapt_decoder or not hasattr(self, "decoder_adapter"):
242237
return features
243238

244-
# Use functional API if state and prng_key are provided
245-
if state is not None and prng_key is not None:
246-
outputs, _ = F(
247-
self.decoder_adapter,
248-
inputs=(features,),
249-
is_training=is_training,
250-
prng_key=prng_key,
251-
state=state["decoder_adapter"],
252-
)
253-
return outputs
254-
255-
# Fall back to direct call if no state/prng_key
256-
return self.decoder_adapter(features)
239+
outputs, _ = F(
240+
self.decoder_adapter,
241+
inputs=(features,),
242+
is_training=is_training,
243+
prng_key=prng_key,
244+
state=state["decoder_adapter"],
245+
)
246+
return outputs

axlearn/audio/adapter_test.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -364,12 +364,16 @@ def test_direct_call_fallback(self):
364364
)
365365
layer = cfg.set(name="test").instantiate(parent=None)
366366

367-
# Initialize params (required for layer setup, but not used in this direct call test)
368-
_ = layer.initialize_parameters_recursively(jax.random.PRNGKey(123))
369-
encoder_features = jax.random.normal(
370-
jax.random.PRNGKey(456), (batch_size, seq_len, encoder_dim)
371-
)
367+
prng_key = jax.random.PRNGKey(123)
368+
prng_key, init_key, input_key = jax.random.split(prng_key, num=3)
369+
layer_params = layer.initialize_parameters_recursively(init_key)
370+
encoder_features = jax.random.normal(input_key, (batch_size, seq_len, encoder_dim))
372371

373-
adapted_features = layer.adapt_encoder_features(encoder_features, is_training=True)
372+
adapted_features = layer.adapt_encoder_features(
373+
encoder_features,
374+
is_training=True,
375+
prng_key=prng_key,
376+
state=layer_params,
377+
)
374378

375379
self.assertEqual(adapted_features.shape, encoder_features.shape)

0 commit comments

Comments
 (0)