Extends the reduction scheduler to support 2D device meshes by replacing the hardcoded single-DIDx assumption with dynamic device-dimension counting. Removes the 1D mesh rank restriction and enables tensor sharding across multiple device dimensions (e.g., data parallel + tensor parallel).
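For readers new to 2D meshes, here is a minimal illustrative sketch (not part of the PR; the device counts are made up) of the rank layout the new test builds with torch.arange(d).reshape(dp_size, tp_size), where the outer mesh dimension maps to mesh_y (DIDy) and the inner one to mesh_x (DIDx):

import torch

# Example values only: 8 devices arranged as a 4 x 2 mesh.
d = 8
tp_size = 2
dp_size = d // tp_size

# Rows index the data-parallel dimension (mesh_y / DIDy),
# columns index the tensor-parallel dimension (mesh_x / DIDx).
mesh_ranks = torch.arange(d).reshape(dp_size, tp_size)
print(mesh_ranks)
# tensor([[0, 1],
#         [2, 3],
#         [4, 5],
#         [6, 7]])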
Key Changes:
Replaced getShardedIterDomain check for DIDx with numDeviceDims() to count all device dimensions dynamically
Updated axis calculations to use num_device_dims instead of hardcoded checks
Maintained restriction preventing mixing of multi-GPU and 3D scheduling
Added comprehensive test validating 2D mesh reduction with proper sharding across mesh_x (DIDx) and mesh_y (DIDy)
Minor Issue:
The comment on lines 37-40 mentions only "DIDx" but should reference all device dimensions, since the code now supports DIDx, DIDy, and DIDz
Confidence Score: 4/5
This PR is safe to merge with minimal risk - the refactoring is clean and well-tested
The generalization from 1D to 2D mesh support is straightforward and mathematically sound. The dynamic counting approach via numDeviceDims() correctly replaces the hardcoded single-DIDx checks, and the test thoroughly validates the new functionality. The only minor issue is an outdated comment that doesn't affect functionality. The score is 4 (not 5) because the comment needs updating and this is a core scheduler change that warrants careful testing in production
No files require special attention - both changes are clean and well-validated
Important Files Changed
File Analysis

| Filename | Score | Overview |
| --- | --- | --- |
| csrc/scheduler/reduction_utils.cpp | 4/5 | Generalized the reduction scheduler from 1D to 2D+ device meshes by replacing hardcoded DIDx checks with dynamic numDeviceDims() counting; removed the mesh rank restriction; comment slightly outdated |
| tests/python/multidevice/test_multidevice.py | 5/5 | Added a comprehensive test for 2D mesh reduction with proper sharding validation across both mesh_x and mesh_y dimensions; handles device count validation correctly |
Sequence Diagram
sequenceDiagram
participant Test as test_reduction_with_2d_mesh
participant Fusion as FusionDefinition
participant Tensor as TensorView
participant Mesh as DeviceMesh
participant Scheduler as scheduleReductionTV
participant Utils as numDeviceDims
Test->>Mesh: Create 2D mesh (dp_size, tp_size)
Test->>Fusion: define_tensor([rows, cols])
Test->>Fusion: sum(inp, [1])
Test->>Tensor: set_device_mesh(mesh)
Test->>Tensor: outer_split(0, dp_size)
Test->>Tensor: parallelize(axis 0, mesh_y/DIDy)
Test->>Tensor: outer_split(-1, tp_size)
Test->>Tensor: parallelize(axis -2, mesh_x/DIDx)
Note over Tensor: Loop domain: [DIDy, DIDx, iter_dim, reduce_dim]
Fusion->>Scheduler: scheduleReductionTV(reduction_tv)
Scheduler->>Utils: numDeviceDims(reduction_tv)
Utils-->>Scheduler: 2 (DIDy + DIDx)
Note over Scheduler: iter_axis = 2<br/>inner_reduce_axis = 2 + has_iter_axis
Scheduler->>Scheduler: Check !schedule_3d with num_device_dims > 0
Note over Scheduler: Schedule remaining dims after device dims
Scheduler-->>Fusion: Scheduled tensor
Fusion->>Test: execute([sharded_input])
Test->>Test: Verify output matches expected
The refactored logic for determining outer_reduce_axis and inner_reduce_axis needs careful validation. The new implementation uses numDeviceDims() and a lambda function, but the original logic used explicit sharded domain checking. Verify that both approaches produce identical results, especially for edge cases with different mesh configurations.
const int64_t num_device_dims = numDeviceDims(reduction_tv);
const int iter_axis = num_device_dims;
const auto [outer_reduce_axis, inner_reduce_axis] =
    [&]() -> std::tuple<int, int> {
  if (rparams->schedule_3d) {
    NVF_ERROR_EQ(
        num_device_dims,
        0,
        "Mixing multi-GPU and 3D schedule is not supported at this moment.");
    return {1, 2};
  } else {
    return {0, num_device_dims + has_iter_axis};
  }
}();
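To make the index arithmetic concrete, here is a small sketch (plain Python, not nvFuser code) that mirrors the lambda above and prints the derived axes for a 1D and a 2D mesh; the loop-domain layouts in the comments follow the sequence diagram above.

def reduce_axes(num_device_dims, has_iter_axis, schedule_3d=False):
    # Mirrors the C++ logic: device dimensions stay outermost, so the
    # iteration axis starts right after them.
    iter_axis = num_device_dims
    if schedule_3d:
        assert num_device_dims == 0, "Mixing multi-GPU and 3D schedule is not supported"
        outer_reduce_axis, inner_reduce_axis = 1, 2
    else:
        outer_reduce_axis, inner_reduce_axis = 0, num_device_dims + has_iter_axis
    return iter_axis, outer_reduce_axis, inner_reduce_axis

# 1D mesh (previous behavior): loop domain [DIDx, iter, reduce]
print(reduce_axes(num_device_dims=1, has_iter_axis=1))  # (1, 0, 2)
# 2D mesh (this PR): loop domain [DIDy, DIDx, iter, reduce]
print(reduce_axes(num_device_dims=2, has_iter_axis=1))  # (2, 0, 3)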
The error check for "Mixing multi-GPU and 3D schedule is not supported" is preserved but the logic for determining when this error applies has changed. Ensure the error conditions are equivalent between old and new implementations.
NVF_ERROR_EQ(
    num_device_dims,
    0,
    "Mixing multi-GPU and 3D schedule is not supported at this moment.");
The new test only covers a specific 2D mesh configuration (2x1 split). Consider adding tests for different mesh shapes and reduction patterns to ensure the implementation works correctly across various 2D mesh configurations; see the parametrization sketch after the test code below.
@pytest.mark.mpi
def test_reduction_with_2d_mesh(multidevice_test):
    d = multidevice_test.size
    tp_size = 2
    # Skip if d is not divisible by tp_size
    if d % tp_size != 0:
        pytest.skip(f"Number of devices ({d}) must be divisible by tp_size ({tp_size})")
    dp_size = d // tp_size
    rank = multidevice_test.rank
    mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d).reshape(dp_size, tp_size))
    with FusionDefinition() as fd:
        inp = fd.define_tensor([-1, -1], dtype=DataType.Float, contiguity=True)
        out = fd.ops.sum(inp, [1])
        fd.add_output(out)
        inp.set_device_mesh(mesh)
        inp.outer_split(0, dp_size)
        inp.axis(0).parallelize(nvfuser.ParallelType.mesh_y)
        inp.outer_split(-1, tp_size)
        inp.axis(-2).parallelize(nvfuser.ParallelType.mesh_x)
    dp_rank = rank // tp_size
    tp_rank = rank % tp_size
    rows_per_rank, cols_per_rank = 2, 3
    rows, cols = dp_size * rows_per_rank, tp_size * cols_per_rank
    inp_ref = torch.arange(rows * cols).reshape(rows, cols).to(torch.float)
    out_ref = inp_ref.sum([-1])
    inp = inp_ref[
        dp_rank * rows_per_rank : (dp_rank + 1) * rows_per_rank,
        tp_rank * cols_per_rank : (tp_rank + 1) * cols_per_rank,
    ].cuda()
    (out,) = fd.execute([inp])
    torch.testing.assert_close(
        out.cpu(), out_ref[dp_rank * rows_per_rank : (dp_rank + 1) * rows_per_rank]
    )
csrc/scheduler/reduction_utils.cpp
The comment mentions only DIDx, but the code now supports all device dimensions (DIDx, DIDy, DIDz):
// Multidevice scheduling: we assume only the outermost domains can be
// parallelized with device dimensions (DIDx, DIDy, DIDz) at this point and
// in that case this reduction scheduler only schedules the remaining domains
// while leaving the device dimensions unchanged.