
Conversation

@vishesh9131
Contributor

took me long enough to ship this one

Hey @markblee @jiarui-lu2, it’s been ages since my last merge... feels good to finally ship this one.

New Features

Audio Model Adapters (axlearn/audio/adapter.py, axlearn/audio/adapter_test.py)

  • Added AudioModelAdapter, a general-purpose bottleneck adapter for fine-tuning audio models (a minimal sketch follows this list).
  • Includes down-projection, activation, and optional LayerNorm or BatchNorm.
  • Supports residual connections with small-scale weight initialization (scale=0.01).
  • Works with both ReLU and GELU activations.
  • Introduced ASRModelAdapter for encoder–decoder ASR models.
  • Provides adapt_encoder_features and adapt_decoder_features methods with functional state handling.
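
For orientation, here is a minimal sketch of the bottleneck-adapter pattern described above. It is not the actual AudioModelAdapter API: the function names, parameter layout, and use of plain JAX (rather than AxLearn's layer classes) are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def init_adapter_params(key, input_dim: int, bottleneck_dim: int, scale: float = 0.01):
    """Initializes down/up projections with small-scale weights (scale=0.01)
    so the residual branch starts close to an identity mapping."""
    k_down, k_up = jax.random.split(key)
    return {
        "w_down": scale * jax.random.normal(k_down, (input_dim, bottleneck_dim)),
        "b_down": jnp.zeros((bottleneck_dim,)),
        "w_up": scale * jax.random.normal(k_up, (bottleneck_dim, input_dim)),
        "b_up": jnp.zeros((input_dim,)),
    }


def adapter_forward(params, x, activation=jax.nn.gelu):
    """Down-project, apply the activation (ReLU or GELU), up-project,
    then add the residual connection."""
    h = activation(x @ params["w_down"] + params["b_down"])
    return x + h @ params["w_up"] + params["b_up"]
```

The real layer additionally supports an optional LayerNorm or BatchNorm on the bottleneck path, which is omitted here for brevity.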

SSM Partition Spec Improvements (axlearn/common/ssm.py, axlearn/common/ssm_test.py)

  • Added sequence parallelism support for Mamba partition specs.
  • default_mamba_dim_to_partition_specs now shards the sequence dimension when "seq" is present.
  • default_output_partition_spec now returns a PartitionSpec instead of a dictionary.
  • Sharding logic improved: batch over non-model/seq axes, sequence over "seq", model over "model" (sketched after this list).
  • Added six new tests across different mesh configurations.
  • Updated keyword arguments in PallasLinearScanMambaRecurrence for consistency.
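
A hypothetical sketch of the sharding rule described above; the actual default_mamba_dim_to_partition_specs helper in axlearn/common/ssm.py may use a different signature and return structure.

```python
from jax.sharding import PartitionSpec


def mamba_partition_spec(mesh_axis_names):
    """Builds a PartitionSpec for a [batch, seq, model] activation:
    batch over all non-model/non-seq axes, sequence over "seq" (when present),
    and the model dimension over "model"."""
    batch_axes = tuple(a for a in mesh_axis_names if a not in ("seq", "model"))
    seq_axis = "seq" if "seq" in mesh_axis_names else None
    model_axis = "model" if "model" in mesh_axis_names else None
    return PartitionSpec(batch_axes or None, seq_axis, model_axis)
```

For example, on a ("data", "seq", "model") mesh this yields PartitionSpec(("data",), "seq", "model"), while on a ("data",) mesh it reduces to PartitionSpec(("data",), None, None).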

Bug Fixes

  • Corrected return type in default_output_partition_spec.
  • Wrapped single tensor inputs in tuples for F() calls.
  • Fixed ASR state passing to properly extract child states (both fixes are sketched after this list).
  • Updated parameter assertions to include layer_norm.bias.
  • Removed input_grain_csv_test.py.
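
The two fixes above roughly follow the pattern below. This is a hedged sketch only: it assumes F is axlearn.common.module.functional and that the ASR adapter state nests child states under names like "encoder_adapter"; the exact signature and state layout should be checked against the code.

```python
import jax
from axlearn.common.module import functional as F


def apply_child_adapter(child_layer, full_state, child_name, features, prng_key):
    # Extract the child layer's state rather than passing the full state dict.
    child_state = full_state[child_name]
    # Wrap the single tensor input in a tuple so F() receives it as one
    # positional argument instead of iterating over the tensor.
    outputs, _ = F(
        child_layer,
        inputs=(features,),
        state=child_state,
        is_training=False,
        prng_key=prng_key,
    )
    return outputs
```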

Testing

All tests pass with broad coverage.

  • Audio adapter tests cover parameter counts, forward passes, and configuration validation.
  • SSM partition spec tests verify mesh configurations with and without sequence or model parallelism.

Impact

  • Enables efficient fine-tuning for audio models with minimal parameter overhead.
  • Improves scalability of Mamba/SSM models through sequence parallelism.
  • Makes partition specs more flexible and consistent for distributed training.

@vishesh9131 vishesh9131 requested a review from a team as a code owner November 3, 2025 03:59
Introduces AudioModelAdapter and ASRModelAdapter for efficient fine-tuning of audio models, along with comprehensive tests. Adds input_grain_csv_test.py for CSV/TSV input processing tests. Updates SSM partition spec helpers to support sequence parallelism and adds corresponding tests in ssm_test.py. Updates .gitignore to exclude run_specific_test.sh.
The function returns PartitionSpec but was annotated as dict[str, PartitionSpec].
This mismatch was causing CI test failures.
The test file was added without implementing the csv_dataset and tsv_dataset
functions in input_grain.py. This caused import errors in CI.
- Fixed TypeError by wrapping single tensor inputs in tuples for F() calls in both adapter.py and adapter_test.py
- Fixed parameter count assertion by including layer_norm.bias in the count calculation
- Fixed state passing by extracting encoder_adapter and decoder_adapter from the full state dict in adapt_encoder_features and adapt_decoder_features
- Fixed expected parameter count from 33664 to 33600 in test_parameter_counts
- Made prng_key and state required parameters in adapt_encoder_features and adapt_decoder_features
- Removed fallback direct module calls which don't work outside invocation context
- Updated test_direct_call_fallback to pass required prng_key and state parameters
- Install uv and use uv pip install to respect ml-dtypes override
- Fixes dependency conflict with ml-dtypes>=0.5,<0.6 vs tensorflow<0.5.0
- uv pip requires either a venv or --system flag
- actions/setup-python creates venv but uv doesn't auto-detect it
- Use --system to install into the Python environment directly
- jit_mamba_scan needs 6 positional args for shard_map compatibility
- This is a nested function within JAX jit decorator
- pylint 2.17+ flags this as R0917 (too-many-positional-arguments); see the sketch below
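
As a hedged illustration of the constraint above (not the actual jit_mamba_scan signature; the argument names are hypothetical): shard_map matches in_specs to arguments by position, so the wrapped inner function has to accept positional arguments even when that trips pylint's R0917 check.

```python
import jax
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


# pylint: disable-next=too-many-positional-arguments
def scan_step(x, a_log, b, c, d, delta):  # hypothetical argument names
    # Each positional argument lines up with one entry of in_specs below.
    return x * a_log + b + c + d + delta


sharded_step = shard_map(
    scan_step,
    mesh=mesh,
    in_specs=(P("data"), P(), P(), P(), P(), P()),
    out_specs=P("data"),
)
```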

- MambaConfig has 24 parameters to match HuggingFace's PretrainedConfig
- pylint 2.17 added R0917 check which flags this legitimate case
- Add disable comment to suppress the warning
- uv pip install works correctly with actions/setup-python venv
- Only uv pip install --system was causing issues
- This matches the Dockerfile approach which works correctly
- uv doesn't work properly with actions/setup-python in GitHub Actions
- Upstream uses pip successfully with same dependency versions
- This should work correctly as-is
- pip doesn't respect tool.uv override-dependencies
- Set VIRTUAL_ENV to pythonLocation so uv detects the venv
- This allows uv to honor ml-dtypes>=0.5,<0.6 override
- Environment variables don't persist across GitHub Actions steps
- Export VIRTUAL_ENV in the same run block as uv pip install
- This ensures uv can detect the venv and honor ml-dtypes override
- Revert to pip (matching upstream) with legacy resolver
- Legacy resolver can install despite ml-dtypes version conflict
- This allows tensorflow 2.17.1 and jax 0.6.2 to coexist
- isort 7.0 has stricter import ordering than 5.x
- Fixed 5 files to match CI expectations
- Added pylint disables for too-many-positional-arguments in gpu_attention.py
- These are pre-existing code style issues, not related to our changes
@vishesh9131 vishesh9131 requested a review from a team as a code owner November 3, 2025 11:04
- Legacy resolver may skip or break google-cloud-aiplatform installation
- Force reinstall with --no-deps to fix pytype import errors
- This ensures pytype can analyze vertexai_tensorboard.py
- Legacy resolver breaks transformers package installation
- Add transformers==4.51.3 to force reinstall for pytype
- This fixes param_converter.py import errors
- Create .venv and export VIRTUAL_ENV + PATH via GITHUB_ENV/GITHUB_PATH
- Install pip+uv, then install extras with uv (honors ml-dtypes override)
- Remove legacy resolver and post-install hacks (aiplatform/transformers)
- Ensures pytype can resolve imports reliably
@vishesh9131
Contributor Author

Hey @samos123
I'm reaching out about a pull request I opened recently on the AxLearn repository. The code owners list may be out of date, since the PR doesn't appear to have been routed to the appropriate reviewers.
Could you please take a look or help direct it to the correct owners? I'd like to make sure the contribution aligns with current development practices and reaches the right maintainers.
