diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py
index 405d5f7661..438511fa55 100644
--- a/transformer_engine/jax/permutation.py
+++ b/transformer_engine/jax/permutation.py
@@ -52,7 +52,7 @@ def token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """
     Dispatch tokens to experts based on routing map.
@@ -101,9 +101,11 @@ def token_dispatch(
     pad_offsets : Optional[jnp.ndarray]
         Per-expert cumulative padding offsets of shape [num_experts] when using padding,
         None otherwise. Pass this to token_combine when unpadding is needed.
-    target_tokens_per_expert : Optional[jnp.ndarray]
-        Aligned token counts per expert of shape [num_experts] when using padding,
-        None otherwise.
+    tokens_per_expert : jnp.ndarray
+        Token counts per expert of shape [num_experts]:
+        - Without padding: actual token counts (sum of routing_map columns)
+        - With padding: aligned token counts (ceil(actual / align_size) * align_size)
+        This gives the effective number of tokens per expert in the output buffer.
 
     Note
     ----
@@ -151,10 +153,10 @@ def _token_dispatch(
     Optional[jnp.ndarray],
     jnp.ndarray,
     Optional[jnp.ndarray],
-    Optional[jnp.ndarray],
+    jnp.ndarray,
 ]:
     """Internal token_dispatch with custom VJP."""
-    (output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert), _ = (
+    (output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert), _ = (
         _token_dispatch_fwd_rule(
             inp,
             routing_map,
@@ -165,7 +167,7 @@ def _token_dispatch(
             use_padding,
         )
     )
-    return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
+    return output, permuted_probs, row_id_map, pad_offsets, tokens_per_expert
 
 
 def _token_dispatch_fwd_rule(
@@ -182,7 +184,7 @@ def _token_dispatch_fwd_rule(
         Optional[jnp.ndarray],
         jnp.ndarray,
         Optional[jnp.ndarray],
-        Optional[jnp.ndarray],
+        jnp.ndarray,
     ],
     Tuple[jnp.ndarray, Optional[jnp.ndarray], int, int, int, bool],
 ]:
@@ -212,11 +214,11 @@ def _token_dispatch_fwd_rule(
 
     with_probs = probs is not None
 
-    if use_padding:
-        # Compute tokens_per_expert internally from routing_map
-        # This can be a traced value since output shape uses worst_case_out_tokens
-        tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
+    # Compute tokens_per_expert from routing_map (actual counts)
+    # This is well-optimized by XLA as a simple column-wise reduction
+    tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)
 
+    if use_padding:
         # Calculate aligned token counts per expert
         target_tokens_per_expert = (jnp.ceil(tokens_per_expert / align_size) * align_size).astype(
             jnp.int32
@@ -242,10 +244,12 @@ def _token_dispatch_fwd_rule(
             hidden_size,
             align_size=align_size,
         )
+
+        # Return aligned counts when using padding
+        out_tokens_per_expert = target_tokens_per_expert
     else:
         # No padding
         pad_offsets = None
-        target_tokens_per_expert = None
 
         output, permuted_probs = permute_with_mask_map(
             inp,
@@ -257,14 +261,20 @@ def _token_dispatch_fwd_rule(
             hidden_size,
         )
 
+        # Return actual counts when not using padding
+        out_tokens_per_expert = tokens_per_expert
+
     # Return (primals, residuals)
+    # out_tokens_per_expert is:
+    # - target_tokens_per_expert (aligned) when using padding
+    # - tokens_per_expert (actual) when not using padding
     residuals = (row_id_map, pad_offsets, num_tokens, num_experts, hidden_size, with_probs)
     return (
         output,
         permuted_probs,
         row_id_map,
         pad_offsets,
-        target_tokens_per_expert,
+        out_tokens_per_expert,
     ), residuals
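
The sketch below (not part of the patch) illustrates the two flavors of tokens_per_expert that the updated docstring describes: the actual per-expert counts from a column-wise reduction of routing_map, and the aligned counts used when use_padding=True. The routing_map values and align_size are made up for illustration; only the two formulas are taken from the diff.

import jax.numpy as jnp

# Illustrative routing map: [num_tokens=6, num_experts=3] mask of
# token -> expert assignments (values chosen arbitrarily for this example).
routing_map = jnp.array(
    [
        [1, 0, 0],
        [1, 1, 0],
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
        [0, 0, 1],
    ],
    dtype=jnp.int32,
)
align_size = 4  # assumed alignment, e.g. for grouped-GEMM tile sizes

# Actual per-expert counts: column-wise reduction over the routing map,
# as the forward rule now computes unconditionally.
tokens_per_expert = jnp.sum(routing_map, axis=0).astype(jnp.int32)  # -> [3, 2, 2]

# Aligned counts returned when use_padding=True:
# ceil(actual / align_size) * align_size
target_tokens_per_expert = (
    jnp.ceil(tokens_per_expert / align_size) * align_size
).astype(jnp.int32)  # -> [4, 4, 4]

print(tokens_per_expert, target_tokens_per_expert)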