Permutation to always return group_size/tokens_per_expert #2613
Conversation
/te_ci jax
/te-ci jax
Greptile Summary
This PR modifies the token dispatch/permutation path so that tokens_per_expert (the per-expert group sizes) is always computed from the routing map and returned.
Confidence Score: 5/5
Sequence Diagram
sequenceDiagram
participant Caller
participant token_dispatch
participant _token_dispatch
participant _token_dispatch_fwd_rule
participant Kernel as Permutation Kernel
Caller->>token_dispatch: inp, routing_map, num_out_tokens, probs?, align_size?
token_dispatch->>token_dispatch: Compute worst_case_out_tokens
token_dispatch->>_token_dispatch: Forward call with parameters
_token_dispatch->>_token_dispatch_fwd_rule: Execute forward rule
Note over _token_dispatch_fwd_rule: NEW: Always compute tokens_per_expert<br/>from routing_map (column sum)
_token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: tokens_per_expert = sum(routing_map, axis=0)
alt With padding (align_size provided)
_token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: Compute target_tokens_per_expert (aligned)
_token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: Compute pad_offsets
_token_dispatch_fwd_rule->>Kernel: permute_with_mask_map_and_pad()
Kernel-->>_token_dispatch_fwd_rule: output, permuted_probs
_token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: out_tokens_per_expert = target_tokens_per_expert
else Without padding
_token_dispatch_fwd_rule->>Kernel: permute_with_mask_map()
Kernel-->>_token_dispatch_fwd_rule: output, permuted_probs
_token_dispatch_fwd_rule->>_token_dispatch_fwd_rule: out_tokens_per_expert = tokens_per_expert
end
_token_dispatch_fwd_rule-->>_token_dispatch: (output, permuted_probs, row_id_map, pad_offsets, out_tokens_per_expert)
_token_dispatch-->>token_dispatch: Return all outputs
token_dispatch-->>Caller: NEW: tokens_per_expert always returned (not Optional)
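For reference, a minimal JAX sketch of the bookkeeping the diagram describes. The helper name and the exact encoding of pad_offsets are assumptions for illustration, not the actual Transformer Engine kernel code; only the column-sum and align-to-align_size behavior are taken from the diagram.

```python
import jax.numpy as jnp

def compute_tokens_per_expert(routing_map, align_size=None):
    """Illustrative reconstruction of the forward-rule bookkeeping.

    routing_map: bool/int array of shape (num_tokens, num_experts); a nonzero
    entry means the token is routed to that expert.
    """
    # NEW behavior: always derive tokens_per_expert as the column sum.
    tokens_per_expert = jnp.sum(routing_map.astype(jnp.int32), axis=0)

    if align_size is None:
        # Without padding, the permuted output is grouped by these exact sizes.
        return tokens_per_expert, tokens_per_expert, None

    # With padding, each expert's group is rounded up to a multiple of align_size.
    target_tokens_per_expert = (
        (tokens_per_expert + align_size - 1) // align_size
    ) * align_size
    # pad_offsets here is one plausible encoding (assumed): the start offset of
    # each padded group in the permuted output.
    pad_offsets = jnp.concatenate(
        [jnp.zeros((1,), jnp.int32), jnp.cumsum(target_tokens_per_expert)[:-1]]
    )
    return tokens_per_expert, target_tokens_per_expert, pad_offsets
```

Per the diagram, the unpadded path returns the raw column sums as out_tokens_per_expert, while the padded path returns the aligned targets.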
jberchtold-nvidia left a comment
LGTM, thanks!
Signed-off-by: tdophung <tdophung@nvidia.com>
Force-pushed from 9a92963 to 7e59509
/te-ci jax
/te_ci jax
Description
To enable expert parallelism in MoE, we need to perform a ragged all-to-all after permutation to redistribute the permuted tokens (grouped by expert) across GPUs so that each GPU stores only a subset of the experts. A ragged all-to-all is used because the number of tokens per expert is usually not the same across experts. This operation requires an argument specifying the number of tokens per expert, often called group sizes in MaxText.
In this PR, we compute these group sizes as part of the permutation operation and return them by default.
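As a rough, hedged illustration (not the MaxText or Transformer Engine code) of how these group sizes feed the ragged all-to-all: the per-expert counts are just column sums of the routing map, and with a contiguous experts-per-rank layout (an assumption for this example) they fold directly into per-destination send sizes.

```python
import jax.numpy as jnp

# Toy routing map: 6 tokens, 4 experts, top-1 routing (one expert per token).
routing_map = jnp.array(
    [[1, 0, 0, 0],
     [0, 1, 0, 0],
     [0, 1, 0, 0],
     [0, 0, 0, 1],
     [1, 0, 0, 0],
     [0, 0, 0, 1]], dtype=jnp.int32)

# Group sizes / tokens_per_expert: one column sum per expert -> [2, 2, 0, 2].
tokens_per_expert = jnp.sum(routing_map, axis=0)

# With 2 expert-parallel ranks each holding 2 contiguous experts, the
# per-destination send sizes for the ragged all-to-all are the group sizes
# summed over each rank's expert slice -> [4, 2].
ep_size = 2
send_sizes = jnp.sum(tokens_per_expert.reshape(ep_size, -1), axis=1)
```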
#2585
#2536
Type of change
Changes
Please list the changes introduced in this PR:
Return tokens_per_expert by default, computed by summing the expert columns of the routing map.
Checklist: