[Enh] Improve type closure for primitive func #552
May need to update the same aiter kernels simultaneously.
Pull request overview
This PR updates FlyDSL’s Python expression-layer wrappers to better preserve (“close over”) DSL types (e.g., Numeric, Vector) when calling primitive ops, and adjusts affected tests/kernels to use the new return-value behavior.
Changes:
- Wrap primitive-op scalar results back into `Numeric` types (and propagate this through helpers like `get_scalar`, `get_leaves`, `ptr_load`, `memref_load`).
- Centralize `memref_load_vec` to return a `Vector` with shape/dtype metadata, and simplify `Tensor.load()` accordingly.
- Update unit tests and a few kernel utilities to align with the revised scalar/vector return types and pipeline string formatting.
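The "type closure" idea in the list above can be illustrated with a toy model. This is only a sketch under stated assumptions: `RawValue`, the `Numeric` subclasses, `rewrap_numeric`, and `add` below are hypothetical stand-ins written for illustration, not FlyDSL's actual classes or helpers. The point is that a wrapper unwraps DSL values on the way in and wraps the raw result back into a DSL type on the way out.

```python
# Toy model only: every name here is a hypothetical stand-in, not FlyDSL's API.


class RawValue:
    """Stands in for a raw IR value that carries a type tag."""

    def __init__(self, type_name, payload):
        self.type_name = type_name
        self.payload = payload


class Numeric:
    """Stands in for a DSL scalar wrapper around a raw value."""

    def __init__(self, value):
        self.value = value

    def ir_value(self):
        return self.value


class Int32(Numeric):
    pass


class Float32(Numeric):
    pass


# Maps a raw type tag to the DSL wrapper that should own it.
_WRAPPERS = {"i32": Int32, "f32": Float32}


def rewrap_numeric(raw):
    """Wrap a raw result in the matching Numeric subclass, if one is known."""
    wrapper = _WRAPPERS.get(raw.type_name)
    return wrapper(raw) if wrapper is not None else raw


def _unwrap(x):
    """Return the raw value behind a Numeric, or x itself if already raw."""
    return x.ir_value() if isinstance(x, Numeric) else x


def raw_add(a, b):
    """A 'primitive op' that only understands raw values of one type."""
    if a.type_name != b.type_name:
        raise TypeError(f"type mismatch: {a.type_name} vs {b.type_name}")
    return RawValue(a.type_name, a.payload + b.payload)


def add(a, b):
    """Expression-layer wrapper: unwrap inputs, re-wrap the result."""
    return rewrap_numeric(raw_add(_unwrap(a), _unwrap(b)))


result = add(Int32(RawValue("i32", 2)), Int32(RawValue("i32", 3)))
```

In this toy, `result` is again an `Int32`, so callers that need the raw value ask for it explicitly with `.ir_value()`, which mirrors the test changes described in the table below.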
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| tests/unit/test_static_vs_dynamic.py | Adjusts dynamic-layout tests to return i32 scalars directly and simplifies pipeline string formatting. |
| tests/unit/test_layout_algebra.py | Updates dynamic test functions to return i32 scalars from fx.get_scalar(...).ir_value() and simplifies pipeline string formatting. |
| python/flydsl/expr/typing.py | Updates IntTuple reconstruction and makes Tensor.load() rely on the updated memref_load_vec wrapper. |
| python/flydsl/expr/primitive.py | Introduces numeric re-wrapping helper and applies it across several primitive ops; moves vector wrapping into memref_load_vec. |
| python/flydsl/expr/math.py | Extends traced math-op wrapping to preserve DSL closure for both Numeric and Vector inputs. |
| kernels/silu_and_mul_fq.py | Simplifies scale-offset computation by relying on the updated fx.get_scalar behavior. |
| kernels/mfma_preshuffle_pipeline.py | Updates crd2idx helper to unwrap int-tuples and cast to index type using the new scalar typing behavior. |
| kernels/layout_utils.py | Updates dynamic-layout crd2idx fallback to unwrap/cast through fx.get_scalar(...).ir_value(). |
CI failed. @sjfeng1999
205ec38 to abb308a
Unwrap DSL types (Int32, etc.) to raw ir.Value at the entry of idx2crd, so that downstream arith ops (shrui, andi) receive proper ir.Value operands instead of Numeric wrappers. Also replace fragile hasattr/str-based type check with isinstance. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…2stage
Eliminate all index-typed arith.constant and arith.index calls in mixed_moe_gemm_2stage.py, replacing them with fx.Int32() DSL types. This prevents type mismatch errors (i32 vs index) when DSL-typed values (from fx.get, fx.idx2crd) participate in arithmetic with index-typed constants. 174 occurrences replaced. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
scf.ForOp requires index-typed ir.Value for lower_bound, upper_bound, step, and iter_args. These were incorrectly converted to fx.Int32() in the previous commit. Restore them to arith.constant(..., index=True). Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
…2stage" This reverts commit 234240c.
This reverts commit c957349.
…oe_gemm_2stage" This reverts commit 0750258.
Revert the bulk arith.constant(index=True)→fx.Int32() replacement in mixed_moe_gemm_2stage.py — the file mixes index and i32 type chains that cannot be uniformly converted. Instead, fix the root cause in _try_coerce_rhs (numeric.py): when Numeric.__mul__ encounters an index-typed ArithValue as the rhs operand, cast it to Int32 so the operands have matching types. This resolves the arith.muli(i32, index) type mismatch without changing any kernel code. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
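The idx2crd commit earlier in this series describes unwrapping DSL values once at function entry and using `isinstance` rather than `hasattr` or string-based type checks. A minimal sketch of that pattern follows, assuming toy stand-in classes; `RawValue`, `Numeric`, `to_raw`, and `shift_and_mask` are hypothetical names for illustration and do not reproduce the real `idx2crd` code.

```python
# Toy model only: these names are hypothetical, not FlyDSL's real API.


class RawValue:
    """Stands in for a raw IR value holding an integer payload."""

    def __init__(self, payload):
        self.payload = payload


class Numeric:
    """Stands in for a DSL scalar wrapper."""

    def __init__(self, raw):
        self._raw = raw

    def ir_value(self):
        return self._raw


class Int32(Numeric):
    pass


def to_raw(x):
    """Unwrap a DSL value to its raw form; pass raw values through."""
    if isinstance(x, Numeric):
        return x.ir_value()
    return x


def shift_and_mask(idx, shift, mask):
    """Unwrap once at entry so the 'arith' steps below only see raw values."""
    raw = to_raw(idx)
    if not isinstance(raw, RawValue):
        raise TypeError(f"expected RawValue or Numeric, got {type(idx).__name__}")
    # Stand-ins for the shrui / andi steps mentioned in the commit message.
    return RawValue((raw.payload >> shift) & mask)


from_wrapped = shift_and_mask(Int32(RawValue(54)), 2, 0b111)
from_raw = shift_and_mask(RawValue(54), 2, 0b111)
```

Because the unwrap happens in one place, wrapped and raw inputs take the same path afterwards, and an unsupported input fails with a clear `TypeError` instead of slipping past an attribute check.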
jhinpan left a comment
I found two issues that still apply on the current head. The earlier mixed_moe index/i32 arithmetic concern appears fixed in this version: sorted_m is back to an index constant path.
…thValue in _try_coerce_rhs
Index.__init__ skipped unwrapping when x was already an Index instance, causing Index(Index(ArithValue)) nesting. Fix by taking .value directly when x is Index, avoiding the nested DSL type. Also add IndexType handling in _try_coerce_rhs: cast index-typed ArithValue to Int32 so Numeric binary ops can process mixed i32/index operands without falling through to ArithValue.__rmul__. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
Wrapping index-typed ArithValue as Int32 caused index_cast (index→i32), changing the IR type of the entire downstream computation chain. This broke MLIR ops that require index-typed operands (vector.load, scf.for). Use Index(rhs) instead to preserve the index type. Also fix Index.__init__ to handle Index(Index) by taking .value directly, preventing nested DSL types. Co-Authored-By: Claude Opus 4 <noreply@anthropic.com>
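The last two commit messages describe two behaviors: `Index(Index(x))` should not nest, and an index-typed right-hand operand should be wrapped as `Index` rather than cast to `Int32`. The sketch below models both with toy classes. It is an assumption-laden illustration: `ArithValue`, `Numeric`, `Index`, `Int32`, and `try_coerce_rhs` here are simplified stand-ins, and the real `_try_coerce_rhs` in numeric.py may differ in signature and behavior.

```python
# Toy model only: simplified stand-ins, not the real numeric.py code.


class ArithValue:
    """Stands in for a raw arithmetic IR value with a type tag."""

    def __init__(self, type_name, payload):
        self.type_name = type_name
        self.payload = payload


class Numeric:
    """Stands in for a DSL scalar wrapper."""

    def __init__(self, x):
        # Taking .value when x is already a Numeric avoids nested wrappers
        # such as Index(Index(ArithValue)).
        self.value = x.value if isinstance(x, Numeric) else x


class Int32(Numeric):
    pass


class Index(Numeric):
    pass


def try_coerce_rhs(rhs):
    """Wrap a raw rhs operand in the DSL type that matches its IR type."""
    if isinstance(rhs, Numeric):
        return rhs
    if isinstance(rhs, ArithValue):
        if rhs.type_name == "index":
            # Keep the index type; casting to Int32 would change the type
            # of everything computed downstream.
            return Index(rhs)
        if rhs.type_name == "i32":
            return Int32(rhs)
    return None  # Unknown operand: let the caller fall back.


nested = Index(Index(ArithValue("index", 4)))
coerced = try_coerce_rhs(ArithValue("index", 7))
```

In this toy, `nested.value` is the underlying `ArithValue` rather than another `Index`, and `coerced` stays index-typed, which is the property the commit message says ops such as vector.load and scf.for rely on.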
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist