Add support for TransformerEngine flash attention in WAN #299
base: main
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
@cpersson-amd I've been out on PTO for a month. I'll take a closer look at this next week. Meanwhile, can you update your branch with the latest in main? Thanks.
src/maxdiffusion/configs/base14.yml
Outdated
  # Parallelism
- mesh_axes: ['data', 'fsdp', 'tensor']
+ mesh_axes: ['data', 'fsdp_tpu', 'tensor']
why rename this to fsdp_tpu?
Some of the latest changes in the pyconfig.py file hardcoded the "fsdp" name into the code, which necessitated renaming the axis in all configs. fsdp_tpu was chosen for clarity; let me know if you would prefer this to be reverted.
I would prefer this to be reverted. Can you explain why these changes with the hardcoded fsdp do not apply to GPU?
Alright, I have reverted the renaming. The additional sharding rules in the pyconfig file are not used by the cudnn_te_flash attention function and do not need to be added. That change was initially made to prevent a mismatch between the "fsdp_tpu" and hardcoded "fsdp" names, which would add duplicate rules for e.g. "activation_length". This change can also be reverted without issue.
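To illustrate the mismatch, here is a hypothetical sketch (not the actual pyconfig.py code; the rule format and axis names are assumptions):

```python
# Hypothetical illustration of the duplicate-rule issue described above.
# The rule format and names are assumptions, not the real pyconfig.py contents.
logical_axis_rules = [
    ("activation_length", "fsdp_tpu"),  # rule coming from a config that renamed the axis
]

# pyconfig.py later appends a rule that hardcodes the "fsdp" name:
logical_axis_rules.append(("activation_length", "fsdp"))

# "activation_length" now ends up with two rules, one per mesh-axis name:
print(logical_axis_rules)
# [('activation_length', 'fsdp_tpu'), ('activation_length', 'fsdp')]
```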
entrpn left a comment:
In general the PR looks good, but I'm still unsure whether adding another axis, fsdp_batch, is really necessary. I would prefer not to add it. The other major thing is switching the mesh_axes from data, fsdp, tensor to data, tensor, fsdp.
  # Parallelism
- mesh_axes: ['data', 'fsdp', 'tensor']
+ mesh_axes: ['data', 'tensor', 'fsdp', 'fsdp_batch']
I'm worried that changing the axis order will significantly affect our overall performance. Is there a reason for changing the order of fsdp and tensor?
Also what does fsdp_batch do? Is it really necessary to introduce a new axis?
The only reason for changing the order was to have the somewhat similar fsdp and fsdp_batch axes side by side. I have now changed the order to [data, fsdp_batch, fsdp, tensor] to align with how it is done in maxtext. I did not realize this could affect performance, my bad.
The current implementation of "fsdp" is closer to sequence/context parallelism. The new "fsdp_batch" implementation is a more classic FSDP approach, where the input is sharded across the batch dimension instead of the sequence dimension. The "fsdp_batch" parallelism is substantially faster on GPUs (~10.5% faster when training on PusaV1) than the current "fsdp" parallelism, so I think it is important to include it in this PR.
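As a minimal JAX sketch of the difference (not the maxdiffusion code; the mesh shape, array shape, and axis names are placeholders, and it assumes at least 4 devices):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Placeholder 2x2 mesh over 4 devices; axis names are illustrative only.
devices = np.array(jax.devices()[:4]).reshape(2, 2)
mesh = Mesh(devices, axis_names=("fsdp_batch", "fsdp"))

x = np.zeros((8, 1024, 512), dtype=np.float32)  # (batch, seq, hidden)

# Current "fsdp": shard along the sequence dimension
# (closer to sequence/context parallelism).
seq_sharded = jax.device_put(x, NamedSharding(mesh, P(None, "fsdp", None)))

# Proposed "fsdp_batch": classic FSDP-style sharding along the batch dimension.
batch_sharded = jax.device_put(x, NamedSharding(mesh, P("fsdp_batch", None, None)))
```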
For the WAN2.1 model I would suggest renaming the parallelisms to more closely reflect how the actual sharding is done, for example:
fsdp -> context (alternatively sequence)
fsdp_batch -> fsdp
What do you think?
Yes, I agree they should be renamed. Let me test this branch on TPU next week to make sure it doesn't affect the current performance, and then we can do the rename.
This PR implements the following:
The code has been tested on WAN 2.1 (training and inference) and Flux (training only) using GPUs.