Support reduction with 2D mesh #5806

Merged: wujingyue merged 3 commits into main from wjy/reduction on Jan 14, 2026

Conversation

@wujingyue (Collaborator)

No description provided.

greptile-apps bot commented Jan 13, 2026

Greptile Overview

Greptile Summary

Extends the reduction scheduler to support 2D device meshes by replacing the hardcoded single-DIDx assumption with dynamic device-dimension counting. Removes the 1D mesh-rank restriction and enables tensor sharding across multiple device dimensions (e.g., data parallel + tensor parallel).
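
For readers unfamiliar with the helper, the sketch below illustrates what "counting all device dimensions" amounts to. It is a minimal sketch, assuming nvFuser's TensorView::getLoopDomain() and IterDomain::isDeviceDim() accessors; the PR itself simply calls the existing numDeviceDims() utility.

    // Sketch only: a hand-rolled equivalent of numDeviceDims(), shown to
    // illustrate the "count every device dimension" idea. Assumes the
    // getLoopDomain()/isDeviceDim() accessors; not code from this diff.
    int64_t countDeviceDims(const TensorView* tv) {
      int64_t count = 0;
      for (IterDomain* id : tv->getLoopDomain()) {
        // isDeviceDim() is true for DIDx, DIDy, and DIDz parallelization.
        if (id->isDeviceDim()) {
          ++count;
        }
      }
      return count;
    }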

Key Changes:

  • Replaced getShardedIterDomain check for DIDx with numDeviceDims() to count all device dimensions dynamically
  • Removed NVF_ERROR_EQ(reduction_tv->getDeviceMesh().rank(), 1) restriction
  • Updated axis calculations to use num_device_dims instead of hardcoded checks
  • Maintained restriction preventing mixing of multi-GPU and 3D scheduling
  • Added comprehensive test validating 2D mesh reduction with proper sharding across mesh_x (DIDx) and mesh_y (DIDy)

Minor Issue:

  • Comment on lines 37-40 mentions only "DIDx" but should reference all device dimensions, since the code now supports DIDx, DIDy, and DIDz

Confidence Score: 4/5

  • This PR is safe to merge with minimal risk - the refactoring is clean and well-tested
  • The generalization from 1D to 2D mesh support is straightforward and mathematically sound. The dynamic counting approach via numDeviceDims() correctly replaces the hardcoded single-DIDx checks, and the test thoroughly validates the new functionality. The only minor issue is an outdated comment that doesn't affect behavior. Score of 4 (not 5) because the comment needs updating and this is a core scheduler change that warrants careful testing in production.
  • No files require special attention - both changes are clean and well-validated

Important Files Changed

File Analysis

  • csrc/scheduler/reduction_utils.cpp (score 4/5): Generalized reduction scheduler from 1D to 2D+ device meshes by replacing hardcoded DIDx checks with dynamic numDeviceDims() counting; removed mesh rank restriction; comment slightly outdated.
  • tests/python/multidevice/test_multidevice.py (score 5/5): Added comprehensive test for 2D mesh reduction with proper sharding validation across both mesh_x and mesh_y dimensions; handles device count validation correctly.

Sequence Diagram

sequenceDiagram
    participant Test as test_reduction_with_2d_mesh
    participant Fusion as FusionDefinition
    participant Tensor as TensorView
    participant Mesh as DeviceMesh
    participant Scheduler as scheduleReductionTV
    participant Utils as numDeviceDims
    
    Test->>Mesh: Create 2D mesh (dp_size, tp_size)
    Test->>Fusion: define_tensor([rows, cols])
    Test->>Fusion: sum(inp, [1])
    Test->>Tensor: set_device_mesh(mesh)
    Test->>Tensor: outer_split(0, dp_size)
    Test->>Tensor: parallelize(axis 0, mesh_y/DIDy)
    Test->>Tensor: outer_split(-1, tp_size)
    Test->>Tensor: parallelize(axis -2, mesh_x/DIDx)
    
    Note over Tensor: Loop domain: [DIDy, DIDx, iter_dim, reduce_dim]
    
    Fusion->>Scheduler: scheduleReductionTV(reduction_tv)
    Scheduler->>Utils: numDeviceDims(reduction_tv)
    Utils-->>Scheduler: 2 (DIDy + DIDx)
    
    Note over Scheduler: iter_axis = 2<br/>inner_reduce_axis = 2 + has_iter_axis
    
    Scheduler->>Scheduler: Check !schedule_3d with num_device_dims > 0
    Note over Scheduler: Schedule remaining dims after device dims
    Scheduler-->>Fusion: Scheduled tensor
    
    Fusion->>Test: execute([sharded_input])
    Test->>Test: Verify output matches expected

github-actions bot commented Jan 13, 2026

Review updated until commit 53a0eff

Description

  • Refactor reduction scheduling to support 2D device meshes using numDeviceDims()

  • Simplify axis calculation logic for multi-device scenarios

  • Add validation to prevent mixing 3D scheduling with multi-GPU

  • Add test_reduction_with_2d_mesh to verify 2D mesh reduction functionality

Changes walkthrough

Relevant files

Enhancement: csrc/scheduler/reduction_utils.cpp (+15/-17)
Refactor reduction scheduling for 2D mesh support

  • Replace sharded IterDomain check with numDeviceDims() for general multi-device support
  • Refactor axis calculation logic using tuple return for outer/inner reduce axes
  • Add validation to prevent 3D scheduling conflicts with multi-GPU
  • Simplify iter_axis determination based on number of device dimensions

Tests: tests/python/multidevice/test_multidevice.py (+42/-0)
Add 2D mesh reduction test case

  • Add test_reduction_with_2d_mesh function for 2D device mesh testing
  • Create 2D DeviceMesh with tensor and data parallelism dimensions
  • Test reduction operations with proper mesh partitioning
  • Verify correctness by comparing with reference tensor operations

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests

⚡ Recommended focus areas for review

Logic Correctness

The refactored logic for determining outer_reduce_axis and inner_reduce_axis needs careful validation. The new implementation uses numDeviceDims() and a lambda, whereas the original logic used explicit sharded-domain checking. Verify that both approaches produce identical results, especially for edge cases with different mesh configurations.

    const int64_t num_device_dims = numDeviceDims(reduction_tv);
    const int iter_axis = num_device_dims;
    const auto [outer_reduce_axis, inner_reduce_axis] =
        [&]() -> std::tuple<int, int> {
      if (rparams->schedule_3d) {
        NVF_ERROR_EQ(
            num_device_dims,
            0,
            "Mixing multi-GPU and 3D schedule is not supported at this "
            "moment.");
        return {1, 2};
      } else {
        return {0, num_device_dims + has_iter_axis};
      }
    }();
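
As a sanity check of the new branch, here is the arithmetic traced for the 2D-mesh case the new test exercises (loop domain [DIDy, DIDx, iter_dim, reduce_dim], schedule_3d == false, has_iter_axis == true). This is a worked example, not code from the diff:

    // Worked example (illustrative): values produced by the lambda above for
    // the test's 2D mesh, loop domain [DIDy, DIDx, iter_dim, reduce_dim].
    const int64_t num_device_dims = 2;              // DIDy + DIDx
    const int iter_axis = num_device_dims;          // = 2 -> iter_dim
    // schedule_3d == false, so the else-branch returns
    // {0, num_device_dims + has_iter_axis}:
    const int outer_reduce_axis = 0;
    const int inner_reduce_axis = 2 + /*has_iter_axis=*/1;  // = 3 -> reduce_dim
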
Error Handling Consistency

The error check for "Mixing multi-GPU and 3D schedule is not supported" is preserved, but the logic for determining when it applies has changed. Ensure the error conditions are equivalent between the old and new implementations.

    NVF_ERROR_EQ(
        num_device_dims,
        0,
        "Mixing multi-GPU and 3D schedule is not supported at this "
        "moment.");
Test Coverage

The new test covers only a single mesh shape (tp_size fixed at 2, dp_size = d / 2). Consider adding tests for different mesh shapes and reduction patterns to ensure the implementation works correctly across various 2D mesh configurations.

    @pytest.mark.mpi
    def test_reduction_with_2d_mesh(multidevice_test):
        d = multidevice_test.size
        tp_size = 2
    
        # Skip if d is not divisible by tp_size
        if d % tp_size != 0:
            pytest.skip(f"Number of devices ({d}) must be divisible by tp_size ({tp_size})")
    
        dp_size = d // tp_size
        rank = multidevice_test.rank
    
        mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d).reshape(dp_size, tp_size))
    
        with FusionDefinition() as fd:
            inp = fd.define_tensor([-1, -1], dtype=DataType.Float, contiguity=True)
            out = fd.ops.sum(inp, [1])
            fd.add_output(out)
    
            inp.set_device_mesh(mesh)
            inp.outer_split(0, dp_size)
            inp.axis(0).parallelize(nvfuser.ParallelType.mesh_y)
            inp.outer_split(-1, tp_size)
            inp.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
    
        dp_rank = rank // tp_size
        tp_rank = rank % tp_size
    
        rows_per_rank, cols_per_rank = 2, 3
        rows, cols = dp_size * rows_per_rank, tp_size * cols_per_rank
        inp_ref = torch.arange(rows * cols).reshape(rows, cols).to(torch.float)
        out_ref = inp_ref.sum([-1])
        inp = inp_ref[
            dp_rank * rows_per_rank : (dp_rank + 1) * rows_per_rank,
            tp_rank * cols_per_rank : (tp_rank + 1) * cols_per_rank,
        ].cuda()
        (out,) = fd.execute([inp])
        torch.testing.assert_close(
            out.cpu(), out_ref[dp_rank * rows_per_rank : (dp_rank + 1) * rows_per_rank]
        )

greptile-apps bot left a comment: 1 file reviewed, 1 comment

greptile-apps bot left a comment: no files reviewed, no comments

@wujingyue (Collaborator, Author)

!test

wujingyue requested a review from Priya2698 on January 13, 2026 at 04:28

greptile-apps bot left a comment: no files reviewed, no comments

@wujingyue (Collaborator, Author)

!test

greptile-apps bot left a comment: 1 file reviewed, 1 comment

greptile-apps bot commented Jan 13, 2026

Additional Comments (1)

csrc/scheduler/reduction_utils.cpp
Comment mentions only DIDx, but the code now supports all device dimensions (DIDx, DIDy, DIDz):

      // Multidevice scheduling: we assume only the outermost domains can be
      // parallelized with device dimensions (DIDx, DIDy, DIDz) at this point and
      // in that case this reduction scheduler only schedules the remaining domains
      // while leaving the device dimensions unchanged.
    

wujingyue merged commit b99ef9b into main on Jan 14, 2026, with 65 checks passed.
wujingyue deleted the wjy/reduction branch on January 14, 2026 at 21:39.
wujingyue added a commit that referenced this pull request on Jan 14, 2026.