Conversation

@tdophung (Collaborator) commented Jan 21, 2026

Description

To enable expert parallelism in MoE, we need to perform a ragged all-to-all after permutation to redistribute the permuted tokens (grouped by expert) across all GPUs, so that each GPU stores only a subset of the experts. A ragged all-to-all is used because the number of tokens per expert is usually not the same across experts. To perform this ragged all-to-all, we need to pass it the number of tokens per expert (tokens_per_expert), often called group sizes in MaxText.

In this PR, we compute these group sizes as part of the permutation operation and return them by default.

#2585
#2536
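For illustration, the group sizes exposed by this PR are just the column sums of the routing map. A toy sketch (variable names here are for the example only, not the library's API):

```python
import jax.numpy as jnp

# Toy example: 6 tokens routed to 4 experts (top-1 routing for simplicity).
# routing_map[t, e] == 1 iff token t is sent to expert e.
routing_map = jnp.array(
    [
        [1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
    ],
    dtype=jnp.int32,
)

# Group sizes for the ragged all-to-all: a column-wise sum gives the
# number of tokens assigned to each expert.
tokens_per_expert = jnp.sum(routing_map, axis=0)
print(tokens_per_expert)  # [3 1 2 0]
```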

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:
Return tokens_per_expert by default, computed by summing the per-expert columns of the routing map.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@tdophung (Collaborator, Author)

/te_ci jax

@tdophung (Collaborator, Author)

/te-ci jax

greptile-apps bot (Contributor) commented Jan 21, 2026

Greptile Summary

This PR modifies the token_dispatch function to always return tokens_per_expert as a non-optional array, addressing the requirement for expert parallelism in MOE models where ragged all-to-all operations need per-expert token counts.

Key changes:

  • Changed return type from Optional[jnp.ndarray] to jnp.ndarray for the 5th return value
  • tokens_per_expert is now computed for both padded and non-padded cases by summing routing_map columns
  • When using padding: returns aligned token counts (target_tokens_per_expert)
  • When not using padding: returns actual token counts (tokens_per_expert)
  • Documentation updated to reflect that this value is always returned

Implementation details:

  • The computation jnp.sum(routing_map, axis=0) was moved outside the padding conditional block (lines 217-219); see the sketch after this list
  • Well-optimized by XLA as a simple column-wise reduction
  • Maintains backward compatibility since existing tests already unpack all 5 return values
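A minimal sketch of the branch logic described above, written as a hypothetical standalone helper (the actual forward rule in permutation.py also returns the permuted output, row_id_map, and pad_offsets):

```python
import jax.numpy as jnp

def out_tokens_per_expert(routing_map, align_size=None):
    # Column-wise reduction over the (num_tokens, num_experts) routing map:
    # tokens_per_expert[e] = number of tokens routed to expert e.
    tokens_per_expert = jnp.sum(routing_map, axis=0)
    if align_size is None:
        # Without padding: return the actual per-expert counts.
        return tokens_per_expert
    # With padding: round each count up to the next multiple of align_size,
    # mirroring the aligned target_tokens_per_expert described above.
    return (tokens_per_expert + align_size - 1) // align_size * align_size
```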

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The change is a straightforward refactoring that makes tokens_per_expert always available. The implementation correctly computes the value for both padded and non-padded cases, maintains backward compatibility (tests already unpack 5 values), and addresses a clear need for expert parallelism support in MOE models.
  • No files require special attention

Important Files Changed

Filename Overview
transformer_engine/jax/permutation.py Changed tokens_per_expert from optional to always returned, computing it from routing_map for both padded and non-padded cases

Sequence Diagram

sequenceDiagram
    participant Caller
    participant token_dispatch
    participant _token_dispatch
    participant _token_dispatch_fwd_rule
    participant Kernel as Permutation Kernel

    Caller->>token_dispatch: inp, routing_map, num_out_tokens, probs?, align_size?
    token_dispatch->>token_dispatch: Compute worst_case_out_tokens
    token_dispatch->>_token_dispatch: Forward call with parameters
    _token_dispatch->>_token_dispatch_fwd_rule: Execute forward rule
    
    Note over _token_dispatch_fwd_rule: NEW: Always compute tokens_per_expert<br/>from routing_map (column sum)
    _token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: tokens_per_expert = sum(routing_map, axis=0)
    
    alt With padding (align_size provided)
        _token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: Compute target_tokens_per_expert (aligned)
        _token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: Compute pad_offsets
        _token_dispatch_fwd_rule->>Kernel: permute_with_mask_map_and_pad()
        Kernel-->>_token_dispatch_fwd_rule: output, permuted_probs
        _token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: out_tokens_per_expert = target_tokens_per_expert
    else Without padding
        _token_dispatch_fwd_rule->>Kernel: permute_with_mask_map()
        Kernel-->>_token_dispatch_fwd_rule: output, permuted_probs
        _token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: out_tokens_per_expert = tokens_per_expert
    end
    
    _token_dispatch_fwd_rule-->>_token_dispatch: (output, permuted_probs, row_id_map, pad_offsets, out_tokens_per_expert)
    _token_dispatch-->>token_dispatch: Return all outputs
    token_dispatch-->>Caller: NEW: tokens_per_expert always returned (not Optional)
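A caller-side sketch inferred from the diagram above. The import path, argument names, and return ordering are assumptions read off the summary, not verified against the library, so treat this as pseudocode rather than the actual API:

```python
# Assumed import path, based on the changed file transformer_engine/jax/permutation.py.
from transformer_engine.jax.permutation import token_dispatch

# inp: (num_tokens, hidden) activations; routing_map: (num_tokens, num_experts) mask.
output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert = token_dispatch(
    inp, routing_map, num_out_tokens, probs, align_size
)

# After this PR, tokens_per_expert is always a jnp.ndarray (no longer Optional),
# so it can be passed straight to the ragged all-to-all as per-expert group sizes.
```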

@jberchtold-nvidia (Collaborator) left a comment

LGTM, thanks!

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung force-pushed the return_tokens_per_expert branch from 9a92963 to 7e59509 on January 22, 2026 01:33
@tdophung (Collaborator, Author)

/te-ci jax

@tdophung (Collaborator, Author)

/te_ci jax

@tdophung merged commit 3d46bf6 into NVIDIA:main on Jan 22, 2026
21 of 24 checks passed