[Enh] Improve type closure for primitive func #552

Open
sjfeng1999 wants to merge 22 commits into main from pr/enh-type-closure

Conversation

@sjfeng1999

Collaborator

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

Copilot AI review requested due to automatic review settings May 21, 2026 10:43
sjfeng1999 commented May 21, 2026

Collaborator Author

The same aiter kernels may need to be updated simultaneously.

Copilot AI left a comment

Contributor

Pull request overview

This PR updates FlyDSL’s Python expression-layer wrappers to better preserve (“close over”) DSL types (e.g., Numeric, Vector) when calling primitive ops, and adjusts affected tests/kernels to use the new return-value behavior.

Changes:

  • Wrap primitive-op scalar results back into Numeric types (and propagate this through helpers like get_scalar, get_leaves, ptr_load, memref_load).
  • Centralize memref_load_vec to return a Vector with shape/dtype metadata, and simplify Tensor.load() accordingly.
  • Update unit tests and a few kernel utilities to align with the revised scalar/vector return types and pipeline string formatting.
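
The re-wrapping behavior described in the first bullet can be illustrated with a minimal, self-contained sketch. Everything here is a hypothetical stand-in: `RawValue`, `wrap_numeric`, and `raw_get_scalar` are invented for illustration and are not FlyDSL's actual classes or helpers. The sketch only shows the general pattern of a primitive returning a DSL wrapper instead of a raw IR value.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RawValue:
    """Hypothetical stand-in for a raw IR value (not the real ir.Value)."""
    type_name: str
    expr: str


class Numeric:
    """Hypothetical stand-in for a DSL scalar wrapper."""
    type_name = ""

    def __init__(self, value: RawValue):
        if value.type_name != self.type_name:
            raise TypeError(f"expected {self.type_name}, got {value.type_name}")
        self.value = value

    def ir_value(self) -> RawValue:
        # Unwrap back to the raw value when a low-level op needs it.
        return self.value


class Int32(Numeric):
    type_name = "i32"


class Float32(Numeric):
    type_name = "f32"


# Map from raw element type to the wrapper class that represents it.
_WRAPPERS = {cls.type_name: cls for cls in (Int32, Float32)}


def wrap_numeric(raw: RawValue):
    """Re-wrap a raw result in its DSL type; pass unknown types through."""
    wrapper = _WRAPPERS.get(raw.type_name)
    return wrapper(raw) if wrapper is not None else raw


def raw_get_scalar(name: str) -> RawValue:
    """Assumed low-level primitive that returns a raw i32 value."""
    return RawValue("i32", f"get_scalar({name})")


def get_scalar(name: str):
    """Expression-layer wrapper: the result stays inside the DSL type system."""
    return wrap_numeric(raw_get_scalar(name))


scalar = get_scalar("stride")
```

In this sketch a caller that needs the raw value calls `scalar.ir_value()` explicitly, which mirrors the `fx.get_scalar(...).ir_value()` pattern mentioned in the per-file summary.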

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.

Show a summary per file
| File | Description |
| --- | --- |
| tests/unit/test_static_vs_dynamic.py | Adjusts dynamic-layout tests to return i32 scalars directly and simplifies pipeline string formatting. |
| tests/unit/test_layout_algebra.py | Updates dynamic test functions to return i32 scalars from fx.get_scalar(...).ir_value() and simplifies pipeline string formatting. |
| python/flydsl/expr/typing.py | Updates IntTuple reconstruction and makes Tensor.load() rely on the updated memref_load_vec wrapper. |
| python/flydsl/expr/primitive.py | Introduces numeric re-wrapping helper and applies it across several primitive ops; moves vector wrapping into memref_load_vec. |
| python/flydsl/expr/math.py | Extends traced math-op wrapping to preserve DSL closure for both Numeric and Vector inputs. |
| kernels/silu_and_mul_fq.py | Simplifies scale-offset computation by relying on the updated fx.get_scalar behavior. |
| kernels/mfma_preshuffle_pipeline.py | Updates crd2idx helper to unwrap int-tuples and cast to index type using the new scalar typing behavior. |
| kernels/layout_utils.py | Updates dynamic-layout crd2idx fallback to unwrap/cast through fx.get_scalar(...).ir_value(). |

Comment thread kernels/mfma_preshuffle_pipeline.py
Comment thread python/flydsl/expr/primitive.py Outdated
sjfeng1999 changed the title from "[Enh] Ensure type closure for primitive func" to "[Enh] Improve type closure for primitive func" on May 25, 2026
@coderfeli

Collaborator

CI failed. @sjfeng1999

Copilot AI left a comment

Contributor

Pull request overview

Copilot reviewed 25 out of 25 changed files in this pull request and generated no new comments.

xudoyuan and others added 12 commits June 5, 2026 03:39
Unwrap DSL types (Int32, etc.) to raw ir.Value at the entry of
idx2crd, so that downstream arith ops (shrui, andi) receive
proper ir.Value operands instead of Numeric wrappers.

Also replace fragile hasattr/str-based type check with isinstance.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
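
A rough sketch of the "unwrap at entry" pattern from the commit message above, using `isinstance` rather than `hasattr` or string checks. The classes and the `split_index` helper are hypothetical stand-ins, not the real `idx2crd` or arith ops; the stand-in `shrui` simply rejects anything that is not a raw value, which is the failure mode the commit describes.

```python
class RawValue:
    """Hypothetical stand-in for a raw IR value."""

    def __init__(self, expr: str):
        self.expr = expr


class Numeric:
    """Hypothetical stand-in for a DSL scalar wrapper such as Int32."""

    def __init__(self, value: RawValue):
        self.value = value

    def ir_value(self) -> RawValue:
        return self.value


def unwrap(x):
    """Return the raw value for DSL wrappers; pass raw values through."""
    return x.ir_value() if isinstance(x, Numeric) else x


def shrui(lhs, rhs) -> RawValue:
    """Stand-in low-level op that only accepts raw operands."""
    for operand in (lhs, rhs):
        if not isinstance(operand, RawValue):
            raise TypeError("shrui expects raw values, not DSL wrappers")
    return RawValue(f"shrui({lhs.expr}, {rhs.expr})")


def split_index(idx, shift) -> RawValue:
    """Illustrative entry point: unwrap once, then use raw ops freely."""
    idx, shift = unwrap(idx), unwrap(shift)
    return shrui(idx, shift)


# A wrapped operand and a raw operand can be mixed at the call site.
mixed = split_index(Numeric(RawValue("%idx")), RawValue("%c4"))
```

Unwrapping at the boundary keeps every downstream call inside the function free of per-operand type checks.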
…2stage

Eliminate all index-typed arith.constant and arith.index calls in
mixed_moe_gemm_2stage.py, replacing them with fx.Int32() DSL types.
This prevents type mismatch errors (i32 vs index) when DSL-typed
values (from fx.get, fx.idx2crd) participate in arithmetic with
index-typed constants.

174 occurrences replaced.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
scf.ForOp requires index-typed ir.Value for lower_bound, upper_bound,
step, and iter_args. These were incorrectly converted to fx.Int32() in
the previous commit. Restore them to arith.constant(..., index=True).

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Revert the bulk arith.constant(index=True)→fx.Int32() replacement in
mixed_moe_gemm_2stage.py — the file mixes index and i32 type chains
that cannot be uniformly converted.

Instead, fix the root cause in _try_coerce_rhs (numeric.py): when
Numeric.__mul__ encounters an index-typed ArithValue as the rhs
operand, cast it to Int32 so the operands have matching types. This
resolves the arith.muli(i32, index) type mismatch without changing
any kernel code.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>

@jhinpan jhinpan left a comment

Contributor

I found two issues that still apply on the current head. The earlier mixed_moe index/i32 arithmetic concern appears fixed in this version: sorted_m is back to an index constant path.

Comment thread python/flydsl/expr/__init__.py
Comment thread python/flydsl/expr/arith.py
Comment thread kernels/mixed_moe_gemm_2stage.py
Comment thread python/flydsl/expr/numeric.py Outdated
Comment thread kernels/mixed_moe_gemm_2stage.py
Comment thread kernels/gemm_fp8fp4_gfx1250.py
Comment thread python/flydsl/expr/math.py
Comment thread python/flydsl/expr/primitive.py
Comment thread python/flydsl/expr/primitive.py
Comment thread python/flydsl/expr/numeric.py Outdated
xudoyuan and others added 6 commits June 8, 2026 10:49
…thValue in _try_coerce_rhs

Index.__init__ skipped unwrapping when x was already an Index instance,
causing Index(Index(ArithValue)) nesting. Fix by taking .value directly
when x is Index, avoiding the nested DSL type.

Also add IndexType handling in _try_coerce_rhs: cast index-typed
ArithValue to Int32 so Numeric binary ops can process mixed i32/index
operands without falling through to ArithValue.__rmul__.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
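
The `Index(Index(ArithValue))` nesting issue can be sketched as follows. The `ArithValue` and `Index` classes below are simplified, hypothetical stand-ins for the ones in numeric.py; only the constructor logic described in the commit message is modeled.

```python
class ArithValue:
    """Hypothetical stand-in for an index-typed arithmetic value."""

    def __init__(self, type_name: str):
        self.type_name = type_name


class Index:
    """Hypothetical stand-in for the Index DSL type."""

    def __init__(self, x):
        # If x is already an Index, take its payload directly so the
        # result never holds another Index inside it.
        if isinstance(x, Index):
            x = x.value
        self.value = x


raw = ArithValue("index")
nested = Index(Index(raw))
```

With this constructor, wrapping is idempotent: wrapping an already wrapped value yields a wrapper around the same underlying payload.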
Wrapping index-typed ArithValue as Int32 caused index_cast (index→i32),
changing the IR type of the entire downstream computation chain. This
broke MLIR ops that require index-typed operands (vector.load, scf.for).

Use Index(rhs) instead to preserve the index type.

Also fix Index.__init__ to handle Index(Index) by taking .value directly,
preventing nested DSL types.

Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
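
A minimal sketch of the coercion choice described above: an index-typed right-hand operand is wrapped as `Index` so its IR type is preserved, rather than being cast to `Int32`. The function `try_coerce_rhs` and all classes here are hypothetical simplifications, not the actual `_try_coerce_rhs` implementation.

```python
class ArithValue:
    """Hypothetical stand-in for a raw arithmetic value with an IR type."""

    def __init__(self, type_name: str, expr: str):
        self.type_name = type_name
        self.expr = expr


class Numeric:
    """Hypothetical base class for DSL scalar wrappers."""

    def __init__(self, value: ArithValue):
        self.value = value


class Int32(Numeric):
    pass


class Index(Numeric):
    pass


def try_coerce_rhs(rhs):
    """Wrap rhs in a DSL type without changing its IR type.

    Returns None when the operand is not recognized, so the caller can
    fall back to another dispatch path.
    """
    if isinstance(rhs, Numeric):
        return rhs
    if isinstance(rhs, ArithValue):
        if rhs.type_name == "index":
            # Preserve the index type; no index -> i32 cast is introduced.
            return Index(rhs)
        if rhs.type_name == "i32":
            return Int32(rhs)
    return None


coerced = try_coerce_rhs(ArithValue("index", "%c4"))
```

In this sketch the wrapped value keeps `type_name == "index"`, so operations that require index-typed operands downstream would still see the original type.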

@aoli26 aoli26 left a comment

Contributor

LGTM
