
[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622

Open
timmoon10 wants to merge 71 commits into NVIDIA:main from timmoon10:tmoon/cute-gemm-swiglu

Conversation

@timmoon10 (Collaborator) commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
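
For orientation, the sketch below shows the reference computation a grouped MLP block performs, written in plain PyTorch. The function name, argument layout, and the assumption that tokens arrive pre-sorted by expert are illustrative assumptions, not this PR's API; the experimental fusion appears to target the FC1 grouped GEMM and the SwiGLU that follows it.

```python
# Plain-PyTorch reference for a grouped MLP block (illustration only; names and
# tensor layouts are assumptions, not the API added by this PR).
import torch
import torch.nn.functional as F

def grouped_mlp_reference(
    x: torch.Tensor,                   # [num_tokens, hidden], tokens sorted by expert
    group_sizes: list[int],            # tokens per expert, sums to num_tokens
    fc1_weights: list[torch.Tensor],   # per-expert [2 * ffn, hidden] (gate + linear units)
    fc2_weights: list[torch.Tensor],   # per-expert [hidden, ffn]
) -> torch.Tensor:
    outputs = []
    start = 0
    for size, w1, w2 in zip(group_sizes, fc1_weights, fc2_weights):
        tokens = x[start : start + size]
        h = tokens @ w1.t()                  # FC1: one GEMM per expert group
        gate, linear = h.chunk(2, dim=-1)    # split into gate and linear units
        h = F.silu(gate) * linear            # SwiGLU activation
        outputs.append(h @ w2.t())           # FC2
        start += size
    return torch.cat(outputs, dim=0)
```

A grouped GEMM batches the per-group matrix multiplies into a single kernel launch; the CuTe DSL kernel described above presumably folds the SwiGLU into the same pass.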

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op and support for interleaving SwiGLU gate and linear units (see the sketch after this list)
  • Add a fused operation for grouped MLP
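
As a rough illustration of the second bullet, here is a minimal sketch of a post-scaled SwiGLU with optionally interleaved gate/linear units. The semantics assumed here (scales applied after the activation, interleaving meaning gate and linear units alternate along the channel dimension) are interpretations for illustration, not the PR's exact definitions.

```python
# Sketch of a post-scaled SwiGLU with optionally interleaved gate/linear units.
# The exact semantics are assumptions for illustration purposes.
import torch
import torch.nn.functional as F

def post_scaled_swiglu(
    h: torch.Tensor,        # [num_tokens, 2 * ffn] FC1 output
    scales: torch.Tensor,   # [num_tokens, 1] per-token scales (e.g., router probabilities)
    interleaved: bool = False,
) -> torch.Tensor:
    if interleaved:
        # Gate and linear units alternate along the channel dimension
        gate, linear = h[..., 0::2], h[..., 1::2]
    else:
        # Gate and linear units are stored as two contiguous halves
        gate, linear = h.chunk(2, dim=-1)
    # "Post-scaled": the scale is applied to the SwiGLU output
    return scales * (F.silu(gate) * linear)
```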

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

timmoon10 and others added 4 commits February 5, 2026 02:18
@greptile-apps greptile-apps bot left a comment

9 files reviewed, 4 comments

Comment on lines +290 to +292
quantizer=fc2_input_quantizers[group_idx],
requires_grad=False,
with_gemm_swizzled_scales=True,
Incorrect grad-required flags

In ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward, swiglu_ctx.input_requires_grad and swiglu_ctx.extra_input_requires_grad are set to True unconditionally (and input_requires_grad is set to requires_grad unconditionally). This will make ScaledSwiGLU.fuser_backward compute grad_input and grad_extra_input even when neither input_ nor scales require grads, which violates autograd semantics and can raise an error (e.g., scales.detach() is passed into the fused kernel, but extra_input_requires_grad=True still forces a gradient).

These flags should be set based on the actual requirements:

  • input_requires_grad = input_.requires_grad
  • swiglu_ctx.extra_input_requires_grad = scales.requires_grad
  • and for FC weights, check each parameter's requires_grad (not just weight0), as in the sketch below.
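
A minimal sketch of that fix, using hypothetical variable names for the per-expert weight lists since the surrounding fuser_forward code is not shown here:

```python
# Hypothetical sketch inside ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward:
# derive the requires-grad flags from the actual tensors instead of hard-coding True.
input_requires_grad = input_.requires_grad
swiglu_ctx.input_requires_grad = input_requires_grad
swiglu_ctx.extra_input_requires_grad = scales.requires_grad
# Check every FC parameter, not just weight0 (fc1_weights / fc2_weights are
# hypothetical names for the per-expert parameter lists).
fc1_weights_require_grad = any(w.requires_grad for w in fc1_weights)
fc2_weights_require_grad = any(w.requires_grad for w in fc2_weights)
```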

@greptile-apps greptile-apps bot left a comment

4 files reviewed, no comments


Labels

performance (Performance issues)
