@liqiangxl
Collaborator

No description provided.

@liqiangxl
Collaborator Author

!test

@github-actions

github-actions bot commented Jan 13, 2026

Review updated until commit 5c08e12

Description

  • Add new optimization option to cache TMA loaded buffers in registers

  • Enable register caching by default in inner persistent heuristics

  • Skip recomputation when register caching is enabled to reduce overhead

  • Add logic to cache TMA loaded buffers and inline them for better performance

Changes walkthrough

Relevant files
Enhancement

normalization_inner_tma.cpp: Enable TMA register caching optimization
csrc/scheduler/normalization_inner_tma.cpp (+25/-0)

  • Added include for range-based operations
  • Enable is_circular_buffer_regs_cached option in heuristics
  • Skip recomputation logic when register caching is enabled
  • Add TMA buffer caching and inlining in warp specialized scheduling

normalization_inner_tma.h: Add register caching parameter and utilities
csrc/scheduler/normalization_inner_tma.h (+9/-2)

  • Add is_circular_buffer_regs_cached boolean member variable
  • Update equality operator to include new field
  • Update toString() method to display new option
  • Update hash function to include new field in calculation
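
For readers unfamiliar with the pattern behind the header changes, here is a minimal sketch of how a new boolean option is typically threaded through equality, printing, and hashing. `ToyTmaParams`, its `n_stages` field, and the hash-combine formula are assumptions made for illustration; they are not taken from the actual `InnerNormTmaParams` class, and only the flag name `is_circular_buffer_regs_cached` comes from the PR.

```cpp
#include <cstddef>
#include <functional>
#include <sstream>
#include <string>

// Hypothetical stand-in for the scheduler parameter class. Only the flag name
// comes from the PR; everything else is an illustrative assumption.
struct ToyTmaParams {
  int n_stages = 1;                             // assumed pre-existing field
  bool is_circular_buffer_regs_cached = false;  // new option, off by default

  // Equality has to compare the new flag, otherwise two parameter sets that
  // differ only in register caching would be treated as identical.
  bool sameAs(const ToyTmaParams& other) const {
    return n_stages == other.n_stages &&
        is_circular_buffer_regs_cached == other.is_circular_buffer_regs_cached;
  }

  // Print the flag so dumped heuristics show whether caching is on.
  std::string toString() const {
    std::ostringstream ss;
    ss << "n_stages: " << n_stages << ", is_circular_buffer_regs_cached: "
       << (is_circular_buffer_regs_cached ? "true" : "false");
    return ss.str();
  }

  // Fold the flag into the hash (boost-style combine, chosen for the sketch).
  std::size_t hash() const {
    std::size_t h = std::hash<int>{}(n_stages);
    h ^= std::hash<bool>{}(is_circular_buffer_regs_cached) + 0x9e3779b9 +
        (h << 6) + (h >> 2);
    return h;
  }
};
```

The point of the sketch is that all three methods need the same update: if only some of them see the new field, a cached heuristic could be reused for a configuration it does not match.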

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 No relevant tests
    ⚡ Recommended focus areas for review
    Missing performance validation

    The PR adds a new optimization option that caches TMA loaded buffers in registers to release shared memory barriers faster. However, no performance data or benchmarks are provided to validate the effectiveness of this change. The PR mentions "increased register usage" but doesn't quantify the performance trade-offs. This should be validated with actual performance measurements.

    // Further cache TMA loaded buffer to regs to release shared memory barrier
    // to launch the next TMA load. Inline position is same as TMA loaded tvs.
    if (params->is_circular_buffer_regs_cached) {
      for (auto tv : setup.smem2reg_tvs) {
        if (std::ranges::none_of(tv->getLoopDomain(), [](const IterDomain* id) {
              return id->getParallelType() == ParallelType::BIDx;
            })) {
          continue;
        }
        inlineSelectedAt({tv}, tv, pos_after_bidx);
        exclude_tvs.insert(tv);
      }
    }
    Unused include directive

    The PR adds #include <ranges> but the new code doesn't appear to use any range-based operations. This include should either be removed or the code should be updated to use range-based operations if intended.

    #include <ranges>
    Missing test coverage

    A new optimization feature has been added that could affect scheduling behavior and performance. No new tests are provided to verify the correctness of this feature or to ensure it doesn't introduce regressions. Tests should be added to validate the new caching behavior.

    bool is_circular_buffer_regs_cached = false;

    @greptile-apps
    Contributor

    greptile-apps bot commented Jan 13, 2026

    Greptile Overview

    Greptile Summary

    This PR adds register caching support for TMA-loaded buffers in the inner persistent normalization scheduler. When warp specialization is enabled, TMA-loaded shared memory buffers are cached to registers to immediately release shared memory barriers, allowing the next TMA load to proceed without waiting for computation to complete.

    Key Implementation Details:

    • New boolean flag is_circular_buffer_regs_cached controls the optimization
    • Flag is set to true only when warp specialization conditions are met (n_stages >= 2 and bdimx == 128)
    • In setupPersistentSchedule(), skips recomputation logic when flag is enabled to keep all data in registers
    • In scheduleInnerPersistentWarpSpecialized(), applies proper inlining for cached tensors with BIDx parallelization
    • Trade-off: Improves TMA pipeline throughput at the cost of increased register usage
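
The gating described in the list above can be sketched as a small decision function. This is a simplified illustration, not the scheduler's code: `ToyHeuristicInputs`, `ToyDecision`, and `decide` are hypothetical names, and the real heuristic may check additional conditions that the summary does not mention.

```cpp
// Hypothetical inputs; the field names follow the summary above, not the real API.
struct ToyHeuristicInputs {
  int n_stages;  // number of circular-buffer stages
  int bdimx;     // threads per block along x
};

// Hypothetical result of the decision, for illustration only.
struct ToyDecision {
  bool is_circular_buffer_regs_cached;
  bool recompute_from_smem;
};

ToyDecision decide(const ToyHeuristicInputs& in) {
  // Condition quoted in the summary: n_stages >= 2 and bdimx == 128.
  const bool warp_specialized = in.n_stages >= 2 && in.bdimx == 128;
  // When register caching is on, recomputation from shared memory is skipped
  // so the data stays in registers. Simplification: the real setup may apply
  // recomputation only to some consumers when the flag is off.
  return ToyDecision{warp_specialized, !warp_specialized};
}
```

With these assumptions, `n_stages = 2, bdimx = 128` turns register caching on, while a single-stage configuration or any other block size leaves the flag at its default of false.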

    Changes:

    • Added is_circular_buffer_regs_cached parameter to InnerNormTmaParams class with proper integration in sameAs(), toString(), and hash() methods
    • Modified heuristics to enable flag when warp specialization is active
    • Added conditional logic to skip recomputation and configure inline positions
    • Included <ranges> header for std::ranges::none_of usage

    Minor Issue:

    • One grammar issue in comment (line 224): "increased" should be "increases"

    Confidence Score: 4/5

    • This PR is safe to merge with minimal risk - the optimization is properly gated and follows existing patterns in the codebase.
    • The implementation is well-structured with proper conditional gating. The flag is only enabled in specific warp specialization scenarios, and all related methods (sameAs, toString, hash) are correctly updated. The logic correctly skips recomputation and configures inlining positions. The only issue found is a minor grammar error in a comment. Score is 4/5 due to the minor style issue.
    • No files require special attention - both files have clean, straightforward changes that are consistent with the existing codebase patterns.

    Important Files Changed

    File Analysis

    | Filename | Score | Overview |
    |---|---|---|
    | csrc/scheduler/normalization_inner_tma.h | 5/5 | Added is_circular_buffer_regs_cached flag with proper integration into sameAs(), toString(), and hash() methods. Clean and consistent changes. |
    | csrc/scheduler/normalization_inner_tma.cpp | 4/5 | Implements register caching optimization for warp-specialized TMA loads. Logic is sound with proper gating, though has one minor grammar issue in comment. Adds <ranges> header for std::ranges::none_of. |

    Sequence Diagram

    sequenceDiagram
        participant Heuristics as getInnerPersistentHeuristics
        participant Setup as setupPersistentSchedule
        participant Scheduler as scheduleInnerPersistentWarpSpecialized
        participant TMA as TMA Load
    
        Note over Heuristics: Check if warp specialization<br/>conditions met (n_stages >= 2, bdimx == 128)
        
        alt Warp Specialization Enabled
            Heuristics->>Heuristics: Set circular_buffer_options
            Heuristics->>Heuristics: Set is_circular_buffer_regs_cached = true
        else No Warp Specialization
            Heuristics->>Heuristics: is_circular_buffer_regs_cached = false (default)
        end
        
        Heuristics->>Setup: Pass params with flag
        
        Setup->>Setup: Cache inputs (cacheInputs)
        Setup->>Setup: Create TMA loads to shared memory
        Setup->>Setup: Create register cache (cacheAfter)
        
        alt is_circular_buffer_regs_cached = true
            Note over Setup: Skip recomputation logic<br/>to keep all data in registers
            Setup->>Setup: continue (no recompute)
        else is_circular_buffer_regs_cached = false
            Setup->>Setup: Recompute from smem for each consumer
        end
        
        Setup->>Scheduler: Return setup with smem2reg_tvs
        
        alt is_circular_buffer_regs_cached = true
            Scheduler->>Scheduler: Filter smem2reg_tvs with BIDx
            Scheduler->>Scheduler: Inline at pos_after_bidx
            Note over Scheduler,TMA: Cached registers immediately<br/>release shared memory barrier
            TMA->>TMA: Next TMA load can proceed
        end
        
        Note over Scheduler: Result: Improved TMA pipelining<br/>at cost of increased register usage
    
    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment

    No files reviewed, no comments

    Contributor

    @greptile-apps greptile-apps bot left a comment

    1 file reviewed, 1 comment

    // If regs cache is enabled, no need to further recompute from smem as
    // we want to cache all tma loaded buffers to regs to immediately release
    // the shared memory barrier to launch the next TMA load. Note that, this
    Contributor

    grammar: "increased" should be "increases"

    Suggested change
    // the shared memory barrier to launch the next TMA load. Note that, this
    // increased register usage.

    @liqiangxl liqiangxl requested a review from rdspring1 January 13, 2026 20:04
    @liqiangxl liqiangxl merged commit 291ecaa into main Jan 13, 2026
    63 checks passed
    @liqiangxl liqiangxl deleted the llu/ws_inner_persistent_opt1 branch January 13, 2026 21:20