38 changes: 24 additions & 14 deletions transformer_engine/jax/permutation.py
@@ -52,7 +52,7 @@ def token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """
     Dispatch tokens to experts based on routing map.
@@ -101,9 +101,11 @@ def token_dispatch(
     pad_offsets : Optional[jnp.ndarray]
         Per-expert cumulative padding offsets of shape [num_experts] when using padding,
         None otherwise. Pass this to token_combine when unpadding is needed.
-    target_tokens_per_expert : Optional[jnp.ndarray]
-        Aligned token counts per expert of shape [num_experts] when using padding,
-        None otherwise.
+    tokens_per_expert : jnp.ndarray
+        Token counts per expert of shape [num_experts]:
+        - Without padding: actual token counts (sum of routing_map columns)
+        - With padding: aligned token counts (ceil(actual / align_size) * align_size)
+        This gives the effective number of tokens per expert in the output buffer.
 
     Note
     ----
@@ -151,10 +153,10 @@ def _token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """Internal token_dispatch with custom VJP."""
-    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+    (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
         _token_dispatch_fwd_rule(
             inp,
             routing_map,
@@ -165,7 +167,7 @@ def _token_dispatch(
             use_padding,
         )
     )
-    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+    return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert
 
 
 def _token_dispatch_fwd_rule(
@@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule(
         Optional[jnp.ndarray],
         jnp.ndarray,
         Optional[jnp.ndarray],
-        Optional[jnp.ndarray],
+        jnp.ndarray,
     ],
     Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
 ]:
@@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule(
 
     with_probs = probs is not None
 
-    if use_padding:
-        # Compute tokens_per_expert internally from routing_map
-        # This can be a traced value since output shape uses worst_case_out_tokens
-        tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+    # Compute tokens_per_expert from routing_map (actual counts)
+    # This is well-optimized by XLA as a simple column-wise reduction
+    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
 
+    if use_padding:
         # Calculate aligned token counts per expert
         target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
             jnp.int32
@@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule(
             hidden_size,
             align_size=align_size,
         )
+
+        # Return aligned counts when using padding
+        out_tokens_per_expert = target_tokens_per_expert
     else:
         # No padding
         pad_offsets = None
-        target_tokens_per_expert = None
 
         output, permuted_probs = permute_with_mask_map(
             inp,
@@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule(
             hidden_size,
         )
 
+        # Return actual counts when not using padding
+        out_tokens_per_expert = tokens_per_expert
+
     # Return (primals, residuals)
+    # out_tokens_per_expert is:
+    #   - target_tokens_per_expert (aligned) when using padding
+    #   - tokens_per_expert (actual) when not using padding
     residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
     return (
         output,
         permuted_probs,
         row_id_map,
         pad_offsets,
-        target_tokens_per_expert,
+        out_tokens_per_expert,
     ), residuals
 
 
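For reference, a minimal sketch of the two counting modes described in the updated `tokens_per_expert` docstring, assuming a small hypothetical `routing_map` and `align_size`; it reuses the `jnp.sum` column reduction and `ceil`-based alignment shown in the diff rather than calling `token_dispatch` itself.

```python
import jax.numpy as jnp

# Hypothetical routing map: 6 tokens x 4 experts, 1 where a token is routed to an expert.
routing_map = jnp.array(
    [
        [1, 0, 0, 1],
        [0, 1, 0, 0],
        [1, 0, 1, 0],
        [0, 0, 1, 0],
        [1, 1, 0, 0],
        [0, 0, 0, 1],
    ],
    dtype=jnp.int32,
)
align_size = 4  # hypothetical alignment granularity

# Without padding: actual token counts per expert (column-wise sum of routing_map).
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
print(tokens_per_expert)  # [3 2 2 2]

# With padding: aligned counts, ceil(actual / align_size) * align_size.
target_tokens_per_expert = (
    jnp.ceil(tokens_per_expert / align_size) * align_size
).astype(jnp.int32)
print(target_tokens_per_expert)  # [4 4 4 4]
```

With padding enabled, `token_dispatch` returns the aligned counts as its final output; without padding it returns the actual counts, matching the `out_tokens_per_expert` selection in the forward rule above.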