Skip to content

Simplify canonicalizeReduction and generalize it for multi-dimensional sharding#5794

Merged
wujingyue merged 9 commits intomainfrom
wjy/merge
Jan 13, 2026
Merged

Simplify canonicalizeReduction and generalize it for multi-dimensional sharding#5794
wujingyue merged 9 commits intomainfrom
wjy/merge

Conversation

@wujingyue
Copy link
Collaborator

@wujingyue wujingyue commented Jan 11, 2026

More thoroughly tested by #5806

@wujingyue
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented Jan 11, 2026

Review updated until commit fafcea9

Description

  • Simplify canonicalizeReduction by removing internal helper functions mergeReduction and mergeNonReduction

  • Generalize canonicalizeReduction to support multi-dimensional sharding with better parallel type ordering

  • Update rankOfParallelType to handle DIDz, DIDy, DIDx ordering for multi-dimensional sharding

  • Replace specific merge logic with generic merge_all predicate-based approach

Changes walkthrough

Relevant files
Enhancement
utils.cpp
Update parallel type ranking for multi-dimensional sharding

csrc/multidevice/utils.cpp

  • Update rankOfParallelType function to reorder parallel types as
    Stream=0, DIDz=1, DIDy=2, DIDx=3
  • Update comments to clarify ordering rationale for multi-dimensional
    sharding support
  • +7/-4     
    utils.cpp
    Simplify and generalize canonicalizeReduction for multi-dimensional
    sharding

    csrc/scheduler/utils.cpp

  • Remove mergeReduction and mergeNonReduction helper functions
    (internalized into canonicalizeReduction)
  • Completely rewrite canonicalizeReduction with generic merge_all
    predicate-based approach
  • Add support for multi-dimensional sharding through
    reorderParallelizedToFront integration
  • Maintain same functional behavior while simplifying implementation
  • +44/-72 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review
    Lambda capture safety

    The lambda function merge_all captures tv and num_ordered_dims by reference. While tv remains constant, num_ordered_dims is modified within the lambda's scope (lines 1200, 1206), which could lead to unexpected behavior. Consider capturing by value or restructuring to avoid potential issues.

    auto merge_all = [&](auto pred) -> int64_t {
      int64_t merged = -1;
      for (int64_t i :
           arange(num_ordered_dims, tv->nDims()) | std::views::reverse) {
        if (pred(tv->axis(i))) {
          if (merged >= 0) {
            tv->merge(i, merged);
          }
          merged = i;
        }
      }
      return merged;
    };
    Missing test coverage

    This is a significant refactoring that changes the core logic of canonicalizeReduction. The PR should include comprehensive tests covering both the 2D and 3D scheduling cases, as well as edge cases with different parallel types and dimension configurations.

    std::pair<bool, bool> canonicalizeReduction(
        Fusion* fusion,
        TensorView* tv,
        bool schedule_3d) {
      NVF_ERROR(tv != nullptr);
    
      if (schedule_3d) {
        NVF_ERROR_EQ(merge_3d(tv), 3, "Tried 3D merge, but result is not 3D.");
        if (tv->axis(1)->isBroadcast()) {
          NVF_ERROR(
              !tv->axis(0)->isBroadcast(),
              "3D reduction with first two merged axes broadcast should be 2D "
              "reduction.");
          tv->reorder({{0, 1}});
        }
        return {true, true};
      }
    
      // Merge all reductions and all non-reductions, and reorder them to
      // [DIDs/Streams..., merged non-reduction, merged reduction]. Merging happens
      // incrementally -- first the parallel IterDomains, then the non-reductions,
      // then the reductions.
      //
      // At this stage of scheduling, they can only be DIDs or Streams.
      std::unordered_map<int64_t, int64_t> reorder_map =
          reorderParallelizedToFront(tv);
      auto num_ordered_dims = std::ssize(reorder_map);
    
      // This helper function merges not-yet-ordered IterDomains that satisfy the
      // predicate from back to front. Returns the index of the last merged
      // IterDomain, or -1 if no IterDomains were merged.
      auto merge_all = [&](auto pred) -> int64_t {
        int64_t merged = -1;
        for (int64_t i :
             arange(num_ordered_dims, tv->nDims()) | std::views::reverse) {
          if (pred(tv->axis(i))) {
            if (merged >= 0) {
              tv->merge(i, merged);
            }
            merged = i;
          }
        }
        return merged;
      };
    
      int64_t merged_non_reduction =
          merge_all([](IterDomain* id) { return !id->isReduction(); });
      if (merged_non_reduction >= 0) {
        tv->reorder({{merged_non_reduction, num_ordered_dims}});
        num_ordered_dims++;
      }
    
      int64_t merged_reduction = merge_all([](IterDomain* id) { return true; });
      if (merged_reduction >= 0) {
        tv->reorder({{merged_reduction, num_ordered_dims}});
        num_ordered_dims++;
      }
    
      NVF_ERROR_EQ(num_ordered_dims, tv->nDims(), "Did not merge all IterDomains.");
      return {merged_non_reduction >= 0, merged_reduction >= 0};
    }

    Test failures

    • (High, 95) CUDA driver version too old for runtime – nvFuser matmul/top-k test suites on dlcluster_h100

      Test Name H100 Source
      Ampere/MmaTest.SingleTile/Ampere_16_8_16__bfloat Link
      ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/2048_1_1_0 Link
      BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread4 Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype_double Link
      ClusterReductionTest.SimpleFusionNotAllReduce/cluster_4_dtype_double Link
      CutlassExecutorTest.Nvfp4Matmul_BiasEpilogue Link
      General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/KK_512_256_128_MmaMacro_m64_n128_k16_splitk_2 Link
      General/HopperPlusMatmulSchedulerTest.FusedMultiplySum/MK_512_256_128_MmaMacro_m128_n128_k16_tma_store Link
      General/HopperPlusMatmulSchedulerTest.FusedMultiplySumBiasNeg/MN_512_256_128_MmaMacro_m64_n128_k16_tma_store_splitk_2 Link
      GreedySchedulerTest.ScanNonLocalOutput Link
      ... with 85 more test failures omitted. Check internal logs.
    • (High, 1) Outdated NVIDIA driver on dlcluster_h100 causes CUDA initialization failure in RNG tests

      Test Name H100 Source
      RNGTest.BroadcastingRNG Link

    @greptile-apps
    Copy link
    Contributor

    greptile-apps bot commented Jan 11, 2026

    Greptile Overview

    Greptile Summary

    This PR successfully refactors canonicalizeReduction in scheduler/utils.cpp, eliminating approximately 65 lines of code by consolidating two separate helper functions (mergeReduction and mergeNonReduction) into a single elegant merge_all lambda function.

    Key Improvements

    1. Bug Fix - Stream Handling

    The old code only checked isDeviceDim() when preserving parallel dimensions, which excludes Stream-parallelized IterDomains. This would cause Stream dimensions to be incorrectly merged with other non-reduction dimensions. The new code correctly uses reorderParallelizedToFront, which handles both DIDs and Streams as parallelized dimensions that should remain at the front.

    2. Multi-Dimensional Sharding Support

    The change in multidevice/utils.cpp establishes a clear ordering between DIDs: Stream(0) < DIDz(1) < DIDy(2) < DIDx(3). Previously, all DIDs had the same rank (1), which would not work correctly with multiple DIDs present simultaneously. This ordering enables proper support for multi-dimensional sharding scenarios.

    3. Code Simplification

    The refactored implementation introduces a reusable merge_all lambda that accepts a predicate to selectively merge IterDomains, eliminating code duplication between the old mergeReduction and mergeNonReduction functions.

    Implementation Details

    The new canonicalizeReduction works in three phases:

    1. Reorder parallel dimensions to front: Uses reorderParallelizedToFront to move all parallelized IterDomains (Stream, DIDs) to the beginning in rank order
    2. Merge non-reduction dimensions: Merges all non-reduction IterDomains that haven't been ordered yet
    3. Merge reduction dimensions: Merges all remaining IterDomains (reductions)

    The final layout is: [DIDs/Streams..., merged non-reduction, merged reduction]

    Edge Cases Verified

    The implementation correctly handles:

    • Empty tensors (0 dimensions)
    • Tensors with only reductions or only iterations
    • Multiple parallel dimensions of different types
    • Broadcast dimensions (treated as non-reductions)
    • Parallel dimensions that are also reductions

    Confidence Score: 5/5

    • This PR is safe to merge with high confidence - it fixes a real bug while simplifying the codebase.
    • Score of 5 reflects a well-executed refactoring that improves code quality while fixing a Stream handling bug. The logic has been thoroughly analyzed across multiple edge cases, the implementation is elegant and maintainable, and the changes are backward compatible with existing tests. The new ordering of DIDs properly supports multi-dimensional sharding as intended.
    • No files require special attention - both changes are well-implemented and correct.

    Important Files Changed

    File Analysis

    Filename Score Overview
    csrc/scheduler/utils.cpp 5/5 Refactored canonicalizeReduction to fix Stream handling bug and support multi-dimensional sharding. The new implementation is cleaner and more maintainable.
    csrc/multidevice/utils.cpp 5/5 Updated rankOfParallelType to establish clear ordering between DIDs (DIDz < DIDy < DIDx) to support multi-dimensional sharding properly.

    Sequence Diagram

    sequenceDiagram
        participant Caller
        participant canonicalizeReduction
        participant reorderParallelizedToFront
        participant merge_all
        participant TensorView
    
        Caller->>canonicalizeReduction: canonicalizeReduction(fusion, tv, schedule_3d)
        
        alt schedule_3d == true
            canonicalizeReduction->>TensorView: merge_3d(tv)
            TensorView-->>canonicalizeReduction: 3D merged result
            canonicalizeReduction->>TensorView: handle broadcast reordering
            canonicalizeReduction-->>Caller: {true, true}
        else schedule_3d == false
            canonicalizeReduction->>reorderParallelizedToFront: reorderParallelizedToFront(tv)
            Note over reorderParallelizedToFront: Orders dims by rank:<br/>Stream(0) < DIDz(1) < DIDy(2) < DIDx(3)
            reorderParallelizedToFront->>TensorView: reorder parallel dims to front
            reorderParallelizedToFront-->>canonicalizeReduction: reorder_map (num_ordered_dims)
            
            Note over canonicalizeReduction: First merge_all: non-reductions
            canonicalizeReduction->>merge_all: merge_all(pred: !isReduction)
            loop For each unordered dim (reverse)
                merge_all->>TensorView: Check if !isReduction()
                alt Satisfies predicate
                    merge_all->>TensorView: merge(i, merged)
                end
            end
            merge_all-->>canonicalizeReduction: merged_non_reduction index
            
            alt merged_non_reduction >= 0
                canonicalizeReduction->>TensorView: reorder to num_ordered_dims
                Note over canonicalizeReduction: Increment num_ordered_dims
            end
            
            Note over canonicalizeReduction: Second merge_all: all remaining
            canonicalizeReduction->>merge_all: merge_all(pred: always true)
            loop For each remaining dim (reverse)
                merge_all->>TensorView: merge(i, merged)
            end
            merge_all-->>canonicalizeReduction: merged_reduction index
            
            alt merged_reduction >= 0
                canonicalizeReduction->>TensorView: reorder to num_ordered_dims
                Note over canonicalizeReduction: Increment num_ordered_dims
            end
            
            canonicalizeReduction->>canonicalizeReduction: Assert num_ordered_dims == nDims()
            canonicalizeReduction-->>Caller: {has_iter, has_reduction}
        end
    
    Loading

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @wujingyue
    Copy link
    Collaborator Author

    !test

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    Copy link
    Collaborator

    @liqiangxl liqiangxl left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM.

    Copy link
    Collaborator

    @liqiangxl liqiangxl left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    Please fix failed tests.

    Base automatically changed from wjy/canon to main January 12, 2026 17:11
    @wujingyue
    Copy link
    Collaborator Author

    !test

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @wujingyue
    Copy link
    Collaborator Author

    !test

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @wujingyue
    Copy link
    Collaborator Author

    !test

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @wujingyue
    Copy link
    Collaborator Author

    !test

    @wujingyue wujingyue requested a review from liqiangxl January 13, 2026 00:38
    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    // We coalesce all reduction axes to the right;
    bool has_red_axis = mergeReduction(tv) > 0;

    bool has_iter_axis = mergeNonReduction(tv) > 0;
    Copy link
    Collaborator Author

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    There’s temporal coupling between mergeReduction and mergeNonReduction. For example:

    1. mergeReduction must be called before mergeNonReduction.
    2. mergeNonReduction needs to reorder DIDs, while mergeReduction does not.

    This kind of coupling is much easier to reason about if both steps are inlined into a single function.

    @wujingyue wujingyue changed the title Simplify canonicalizeReduction Simplify canonicalizeReduction and generalize it for multi-dimensional sharding. Jan 13, 2026
    @wujingyue wujingyue changed the title Simplify canonicalizeReduction and generalize it for multi-dimensional sharding. Simplify canonicalizeReduction and generalize it for multi-dimensional sharding Jan 13, 2026
    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    @wujingyue
    Copy link
    Collaborator Author

    !test

    Copy link
    Contributor

    @greptile-apps greptile-apps bot left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    No files reviewed, no comments

    Edit Code Review Agent Settings | Greptile

    Copy link
    Collaborator

    @liqiangxl liqiangxl left a comment

    Choose a reason for hiding this comment

    The reason will be displayed to describe this comment to others. Learn more.

    LGTM.

    @wujingyue wujingyue merged commit ae6746d into main Jan 13, 2026
    62 checks passed
    @wujingyue wujingyue deleted the wjy/merge branch January 13, 2026 15:51
    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    2 participants