Conversation

@cpersson-amd

This PR implements the following:

  • TransformerEngine flash attention for WAN training and inference.
  • A new fsdp sharding parallelism optimized for use on GPUs.
  • Some minor changes to allow for training on flax version 0.11.2.

The code has been tested on WAN 2.1 (training and inference) and Flux (training only) on GPUs.

@google-cla

google-cla bot commented Dec 16, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@cpersson-amd cpersson-amd marked this pull request as draft December 17, 2025 00:18
@cpersson-amd cpersson-amd marked this pull request as ready for review December 17, 2025 10:21
@cpersson-amd cpersson-amd reopened this Dec 17, 2025
@entrpn
Copy link
Collaborator

entrpn commented Dec 30, 2025

@cpersson-amd I've been out on PTO for a month. I'll take a closer look at this next week. Meanwhile, can you update your branch with the latest in main? Thanks.


# Parallelism
- mesh_axes: ['data', 'fsdp', 'tensor']
+ mesh_axes: ['data', 'fsdp_tpu', 'tensor']
Collaborator
why rename this to fsdp_tpu?

Author
Some of the latest changes in pyconfig.py hardcoded the "fsdp" name into the code, which necessitated a rename in all configs. fsdp_tpu was chosen for clarity; let me know if you prefer this be reverted.

Collaborator
I would prefer this to be reverted. Can you explain why these changes with the hardcoded fsdp do not apply to GPUs?

Author
Alright, I have reverted the rename. The additional sharding rules in the pyconfig file are not used by the cudnn_te_flash attention function and do not need to be added. The change was initially made to prevent a mismatch between the "fsdp_tpu" and hardcoded "fsdp" names, which would add duplicate rules for e.g. "activation_length". This change can also be reverted without issue.
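For context, a minimal hypothetical sketch of the mismatch described above, assuming logical-axis rules are stored as (logical_name, mesh_axis) pairs; the names and structure are illustrative, not the actual pyconfig.py contents:

```python
# Illustrative only: a made-up rule list, not the repository's real pyconfig.py.
logical_axis_rules = [
    ("activation_length", "fsdp_tpu"),  # rule written against the renamed mesh axis
]

# Hypothetical hardcoded append, mirroring the behaviour described in the comment:
logical_axis_rules.append(("activation_length", "fsdp"))

# "activation_length" now has two rules, and one of them points at a mesh axis name
# that does not exist in the renamed configs -- hence reverting the rename.
```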

Collaborator

@entrpn entrpn left a comment
In general the PR looks good, but I'm still unsure if adding another axis, fsdp_batch, is really necessary. I would prefer not to add it. The other major thing is switching the mesh_axes from data, fsdp, tensor to data, tensor, fsdp.


# Parallelism
- mesh_axes: ['data', 'fsdp', 'tensor']
+ mesh_axes: ['data', 'tensor', 'fsdp', 'fsdp_batch']
Collaborator
I'm worried changing the axis order will introduce a big change to our performance in general. Is there a reason for changing the order of fsdp and tensor?

Also, what does fsdp_batch do? Is it really necessary to introduce a new axis?

Author
The only reason for changing the order was to place the somewhat similar fsdp and fsdp_batch axes side by side. I have now changed the order to [data, fsdp_batch, fsdp, tensor] to align with how it is done in maxtext. I did not realize this could affect performance, my bad.

The current implementation of "fsdp" is more similar to sequence/context parallelism. The new "fsdp_batch" implementation is a more classic fsdp scheme, where the input is sharded across the batch dimension instead of the sequence dimension. The "fsdp_batch" parallelism is substantially faster on GPUs (~10.5% faster when training on PusaV1) compared to the current "fsdp" parallelism, so I think it is important to include it in this PR.
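To make the distinction concrete, here is a minimal JAX sketch, not the PR's actual code, contrasting the two schemes for a hypothetical (batch, sequence, embed) activation on a one-dimensional "fsdp" mesh axis:

```python
# Hedged illustration only: array shape, axis name, and variable names are made up.
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One-dimensional mesh over all local devices, named "fsdp" for this sketch.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("fsdp",))

x = jnp.zeros((8, 1024, 512))  # hypothetical (batch, sequence, embed) activation

# Current "fsdp" behaviour as described above: shard the sequence dimension,
# which is closer to sequence/context parallelism.
x_seq = jax.device_put(x, NamedSharding(mesh, P(None, "fsdp", None)))

# Proposed "fsdp_batch": shard the batch dimension instead (a classic data-style split).
x_batch = jax.device_put(x, NamedSharding(mesh, P("fsdp", None, None)))
```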

For the WAN2.1 model I would suggest renaming the parallelisms to more closely reflect how the actual sharding is done, for example:
fsdp -> context (alternatively sequence)
fsdp_batch -> fsdp

What do you think?

Collaborator
Yes, I agree they should be renamed. Let me test this branch on TPU next week to make sure it doesn't affect the current performance, and then we can do the name change.
