Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit. Hold Shift + click to select a range.
3e7b9fa
[Enh] Ensure type closure for primitive func
sjfeng1999 May 21, 2026
abb308a
ensure int_tuple covariance relationship
sjfeng1999 Jun 2, 2026
8287aa3
drop coerce in the derived
sjfeng1999 Jun 2, 2026
8813c9a
[Refactor]: rm IR value
xudoyuan Jun 5, 2026
a22c8ec
[Bugfix]: import __all__
xudoyuan Jun 5, 2026
e1c82ce
Merge branch 'main' into pr/enh-type-closure
xudoyuan Jun 5, 2026
9587ccd
Fix DSL type compatibility in layout_utils.idx2crd
xudoyuan Jun 5, 2026
0750258
Replace arith.constant(index=True) with fx.Int32() in mixed_moe_gemm_…
xudoyuan Jun 6, 2026
8874319
Merge branch 'main' into pr/enh-type-closure
xudoyuan Jun 7, 2026
c957349
Format mixed_moe_gemm_2stage.py with black
xudoyuan Jun 8, 2026
234240c
Restore index type for scf.ForOp arguments in mixed_moe_gemm_2stage
xudoyuan Jun 8, 2026
83469fe
Revert "Restore index type for scf.ForOp arguments in mixed_moe_gemm_…
xudoyuan Jun 8, 2026
9cdd537
Revert "Format mixed_moe_gemm_2stage.py with black"
xudoyuan Jun 8, 2026
9ef6433
Revert "Replace arith.constant(index=True) with fx.Int32() in mixed_m…
xudoyuan Jun 8, 2026
c6ebbe2
Revert global index→Int32 replacement, fix type mismatch in framework
xudoyuan Jun 8, 2026
46ed1ca
Fix Index(Index) nesting in Index.__init__ and coerce index-typed Ari…
xudoyuan Jun 8, 2026
212c837
Fix _try_coerce_rhs: wrap index ArithValue as Index instead of Int32
xudoyuan Jun 8, 2026
94ae353
minor fix
sjfeng1999 Jun 9, 2026
b2e2af6
re-export select
sjfeng1999 Jun 9, 2026
271fecd
[Bugfix]: Replace / with //
xudoyuan Jun 10, 2026
14b3bdc
[Bugfix]: Replace / with /
xudoyuan Jun 10, 2026
573a7ee
Merge branch 'main' into pr/enh-type-closure
sjfeng1999 Jun 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions kernels/blockscale_preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def kernel_gemm(
# ---- Wave / lane decomposition ----
wave_size = 64
layout_wave_lane = fx.make_layout((4, wave_size), (64, 1))
coord_wave_lane = fx.idx2crd(tx, layout_wave_lane)
coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane)
wave_id = fx.get(coord_wave_lane, 0)
lane_id = fx.get(coord_wave_lane, 1)

layout_lane16 = fx.make_layout((4, 16), (16, 1))
coord_lane16 = fx.idx2crd(lane_id, layout_lane16)
coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_lane16, 0)
lane_mod_16 = fx.get(coord_lane16, 1)

Expand Down Expand Up @@ -252,8 +252,8 @@ def load_b_packs_k64(base_k, ku: int, ni: int):
k0_base = base_k_bytes // c64_b
k0 = k0_base + ku
k1 = lane_div_16
coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Int32(0))
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
b16 = _buffer_load_vec(
buffer_ops,
vector,
Expand Down
40 changes: 20 additions & 20 deletions kernels/gemm_fp8fp4_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def kernel_mxscale_gemm(
layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (WAVE_SIZE, m_warp * WAVE_SIZE, 16, 1))
else:
layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1))
thr_coord = idx2crd(tx, layout_thr)
thr_coord = idx2crd(fx.Int32(tx), layout_thr)
Comment thread
xudoyuan marked this conversation as resolved.
wave_m_idx, wave_n_idx, lane_kgrp, lane16 = (
fx.get(thr_coord, 0),
fx.get(thr_coord, 1),
Expand All @@ -563,12 +563,12 @@ def kernel_mxscale_gemm(
_bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False)
_bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False)
_bvs_Kt = K // tile_k # total K-tiles
_bvs_mb_a = blk_m / arith.index(128) + wave_m_idx
_bvs_mb_b = blk_n / arith.index(128) + wave_n_idx
_bvs_mb_a = blk_m // arith.index(128) + wave_m_idx
_bvs_mb_b = blk_n // arith.index(128) + wave_n_idx
_bvs_lane4 = lane16 * arith.index(4)

def _bvs_load_scales(rsrc, mb, rep, k_base):
kt = k_base / arith.index(tile_k)
kt = k_base // arith.index(tile_k)
tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128)
vals = []
for ld in range_constexpr(rep // 4): # rep=8 -> 2 groups of 4 i32
Expand All @@ -594,7 +594,7 @@ def _bvs_prefetch(k_base):
).result

def make_desc_a(memref, k_base):
k_packed_off = k_base / arith.index(PACK_FACTOR_A)
k_packed_off = k_base // arith.index(PACK_FACTOR_A)
return _make_tdm_desc(
global_ptr=arg_a,
lds_memref=memref,
Expand All @@ -612,11 +612,11 @@ def make_desc_a(memref, k_base):
)

def make_desc_b(memref, k_base):
k_packed_off = k_base / arith.index(PACK_FACTOR_B)
k_packed_off = k_base // arith.index(PACK_FACTOR_B)
return _make_tdm_desc(
global_ptr=arg_b,
lds_memref=memref,
global_offset=(blk_n / arith.index(16), k_packed_off * arith.index(16)),
global_offset=(blk_n // arith.index(16), k_packed_off * arith.index(16)),
tensor_shape=(N // 16, K_packed_b * 16),
strides=(K_packed_b * 16, 1),
tile_shape=(tile_n // 16, packed_tile_k_b * 16),
Expand All @@ -631,7 +631,7 @@ def make_desc_b(memref, k_base):

def make_desc_a_half(memref, k_base, m_half: int):
row_start = m_half * ab_split_a_rows
k_packed_off = k_base / arith.index(PACK_FACTOR_A)
k_packed_off = k_base // arith.index(PACK_FACTOR_A)
return _make_tdm_desc(
global_ptr=arg_a,
lds_memref=memref,
Expand All @@ -651,11 +651,11 @@ def make_desc_a_half(memref, k_base, m_half: int):

def make_desc_b_half(memref, k_base, n_half: int):
group_start = n_half * ab_split_b_groups
k_packed_off = k_base / arith.index(PACK_FACTOR_B)
k_packed_off = k_base // arith.index(PACK_FACTOR_B)
return _make_tdm_desc(
global_ptr=arg_b,
lds_memref=memref,
global_offset=(blk_n / arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)),
global_offset=(blk_n // arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)),
tensor_shape=(N // 16, K_packed_b * 16),
strides=(K_packed_b * 16, 1),
tile_shape=(ab_split_b_groups, packed_tile_k_b * 16),
Expand All @@ -670,8 +670,8 @@ def make_desc_b_half(memref, k_base, n_half: int):
)

def make_desc_as(memref, k_base):
k_scale_off = k_base / arith.index(SCALE_BLOCK)
outer_off = blk_m / arith.index(wmma_m_rep)
k_scale_off = k_base // arith.index(SCALE_BLOCK)
outer_off = blk_m // arith.index(wmma_m_rep)
inner_off = k_scale_off * arith.index(wmma_m_rep)
return _make_tdm_desc(
global_ptr=arg_a_scale,
Expand All @@ -690,8 +690,8 @@ def make_desc_as(memref, k_base):
)

def make_desc_bs(memref, k_base):
k_scale_off = k_base / arith.index(SCALE_BLOCK)
outer_off = blk_n / arith.index(b_scale_load_rep)
k_scale_off = k_base // arith.index(SCALE_BLOCK)
outer_off = blk_n // arith.index(b_scale_load_rep)
inner_off = k_scale_off * arith.index(b_scale_load_rep)
return _make_tdm_desc(
global_ptr=arg_b_scale,
Expand Down Expand Up @@ -837,7 +837,7 @@ def load_b_frag(lds_buffer, b_lane_bases, wn, ks):

def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols):
"""Precompute scale lane bases (byte offsets)."""
warp_lds_row = warp_base / arith.index(reps) + lane16
warp_lds_row = warp_base // arith.index(reps) + lane16
base = warp_lds_row * arith.index(interleaved_cols)
if const_expr(is_fp4 or is_a8w4):
# FP4/A8W4: always add lane_kgrp offset (no opsel on BScale)
Expand Down Expand Up @@ -1985,8 +1985,8 @@ def _l2_prefetch(k_base):
if const_expr(_effective_l2_pf <= 0):
return
pf_k = k_base + arith.index(_effective_l2_pf * tile_k)
pf_k_packed_a = pf_k / arith.index(PACK_FACTOR_A)
pf_k_packed_b = pf_k / arith.index(PACK_FACTOR_B)
pf_k_packed_a = pf_k // arith.index(PACK_FACTOR_A)
pf_k_packed_b = pf_k // arith.index(PACK_FACTOR_B)
tdm_ops.l2_prefetch_tile(
arg_a,
(blk_m, pf_k_packed_a),
Expand All @@ -1998,7 +1998,7 @@ def _l2_prefetch(k_base):
)
tdm_ops.l2_prefetch_tile(
arg_b,
(blk_n / arith.index(16), pf_k_packed_b * arith.index(16)),
(blk_n // arith.index(16), pf_k_packed_b * arith.index(16)),
(tile_n // 16, packed_tile_k_b * 16),
(K_packed_b * 16, 1),
elem_bytes=1,
Expand Down Expand Up @@ -2057,9 +2057,9 @@ def _l2_prefetch(k_base):
# Match the TDM-store descriptor offsets to the compute wave mapping.
if const_expr(use_fp8_deep_pipeline_schedule):
wave_m_sgpr = wave_id_idx % arith.index(m_warp)
wave_n_sgpr = wave_id_idx / arith.index(m_warp)
wave_n_sgpr = wave_id_idx // arith.index(m_warp)
else:
wave_m_sgpr = wave_id_idx / arith.index(n_warp)
wave_m_sgpr = wave_id_idx // arith.index(n_warp)
wave_n_sgpr = wave_id_idx % arith.index(n_warp)
d_warp_linear_sgpr = wave_m_sgpr * arith.index(n_warp) + wave_n_sgpr
d_warp_off_sgpr = d_warp_linear_sgpr * arith.index(warp_d_bytes) + arith.index(d_output_off)
Expand Down
2 changes: 1 addition & 1 deletion kernels/layernorm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def _load_norm_input_value(index):
mean = sum_val / n_float
var = sumsq_val / n_float - mean * mean
var = (var < c_zero_f).select(c_zero_f, var)
rstd = (var + eps_c).rsqrt(fastmath=fm_fast)
rstd = fmath.rsqrt(var + eps_c, fastmath=fm_fast)

thread_row_max = c_zero_f
for base_idx_int in range_constexpr(0, N, BLOCK_THREADS):
Expand Down
12 changes: 7 additions & 5 deletions kernels/layout_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,15 @@ def idx2crd(idx, layout):
"""
parsed = _parse_layout(layout)

if hasattr(idx, "ir_value"):
idx = idx.ir_value()

if parsed is None or _has_dynamic_strides(parsed[1]):
result = fx.idx2crd(idx, layout)
result = fx.idx2crd(fx.Int32(idx), layout)
ndims = len(parsed[1]) if parsed else 1
return [_wrap(fx.get(result, i)) for i in range(ndims)]

if hasattr(idx, "type") and str(idx.type) != "index":
if isinstance(idx, ir.Value) and not isinstance(idx.type, ir.IndexType):
idx = arith.index_cast(T.index, idx)
shapes, strides = parsed
ndims = len(strides)
Expand Down Expand Up @@ -156,9 +159,8 @@ def crd2idx(crd, layout):
cv = raw
crd_i32.append(cv)
coord_val = fx.make_coord(*crd_i32)
result = fx.crd2idx(coord_val, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value()
if not isinstance(scalar.type, ir.IndexType):
scalar = arith.index_cast(T.index, scalar)
return _wrap(scalar)

Expand Down
27 changes: 13 additions & 14 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@


def crd2idx(crd, layout):
"""crd2idx returning an index-type scalar (unwraps fly.int_tuple)."""
result = fx.crd2idx(crd, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = _arith.IndexCastOp(T.index, scalar).result
return scalar
"""crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple)."""
scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value()
Comment thread
sjfeng1999 marked this conversation as resolved.
if isinstance(scalar.type, ir.IndexType):
return scalar
return _arith.IndexCastOp(T.index, scalar).result


def swizzle_xor16(row, col, k_blocks16):
Expand Down Expand Up @@ -326,7 +325,7 @@ def load_b_raw_w4a16(
k2_base = lane_odd * fx.Index(half_bytes)

coord_pack = (n_blk, k0, k1_local, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
idx_bytes = idx_pack + k2_base

b4 = _buffer_load_vec(
Expand Down Expand Up @@ -464,7 +463,7 @@ def load_b_pack_k32(
k2_base = arith.constant((ki_step % 2) * half_bytes, index=True)

coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)

if unpack_int4:
idx_bytes = idx_pack + k2_base
Expand Down Expand Up @@ -527,7 +526,7 @@ def tile_chunk_coord_i32(
raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}")
chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True)
tile_idx_i32 = tx_i32_base + chunk_off_i32
coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4)
coord_local = fx.idx2crd(fx.Int32(tile_idx_i32), layout_tile_div4)
row_local = fx.get(coord_local, 0)
col_local_i32 = fx.get(coord_local, 1)
return row_local, col_local_i32
Expand Down Expand Up @@ -580,7 +579,7 @@ def lds_store_16b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v16 = vector.bitcast(vec16_ty, vec_part_i32x4)
vector.store(v16, lds_memref, [idx0])

Expand All @@ -607,7 +606,7 @@ def lds_store_8b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v8 = vector.bitcast(vec8_ty, vec_part_i32x2)
vector.store(v8, lds_memref, [idx0])

Expand All @@ -634,7 +633,7 @@ def lds_store_4b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v4 = vector.bitcast(vec4_ty, vec_part_i32x1)
vector.store(v4, lds_memref, [idx0])

Expand All @@ -660,14 +659,14 @@ def lds_load_pack_k32(
col_base_swz = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
if ck_lds128:
coord_a16 = (curr_row_a_lds, col_base_swz)
idx_a16 = crd2idx(coord_a16, layout_lds) + lds_base
idx_a16 = crd2idx(tuple(fx.Int32(c) for c in coord_a16), layout_lds) + lds_base
loaded_a16 = vector.load_op(vec16_ty, lds_memref, [idx_a16])
a_vec128 = vector.bitcast(vec2_i64_ty, loaded_a16)
return vector.extract(a_vec128, static_position=[half], dynamic_position=[])
else:
col_swizzled = col_base_swz + (half * 8)
coord_a = (curr_row_a_lds, col_swizzled)
idx_a = crd2idx(coord_a, layout_lds) + lds_base
idx_a = crd2idx(tuple(fx.Int32(c) for c in coord_a), layout_lds) + lds_base
loaded_a8 = vector.load_op(vec8_ty, lds_memref, [idx_a])
a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8)
return vector.extract(a_vec64, static_position=[0], dynamic_position=[])
Expand Down
18 changes: 9 additions & 9 deletions kernels/mixed_moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,10 +729,10 @@ def load_x_tile(base_k):
return parts

# Wave/lane decomposition (identical to stage2)
coord_wl = idx2crd(tx, layout_tx_wave_lane)
coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = layout_get(coord_wl, 0)
lane_id = layout_get(coord_wl, 1)
coord_l16 = idx2crd(lane_id, layout_lane16)
coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)
row_a_lds = lane_mod_16
Expand Down Expand Up @@ -763,12 +763,12 @@ def load_x_tile(base_k):
global_n = by_n + n_tile_base + c_offset + lane_mod_16
# Gate/interleave: rows [expert_off, expert_off + 2*inter_dim)
gate_row_w = expert_off_idx + global_n
gate_coord = idx2crd(gate_row_w, layout_n_blk_intra)
gate_coord = idx2crd(fx.Int32(gate_row_w), layout_n_blk_intra)
gate_n_blk_list.append(layout_get(gate_coord, 0))
gate_n_intra_list.append(layout_get(gate_coord, 1))
if const_expr(not mock_gate_only and not gate_up_interleave):
up_row_w = gate_row_w + inter_idx
up_coord = idx2crd(up_row_w, layout_n_blk_intra)
up_coord = idx2crd(fx.Int32(up_row_w), layout_n_blk_intra)
up_n_blk_list.append(layout_get(up_coord, 0))
up_n_intra_list.append(layout_get(up_coord, 1))

Expand Down Expand Up @@ -799,7 +799,7 @@ def load_b_packs_k64(base_k, ku: int, n_blk, n_intra):
k0 = base_k_bytes // c64 + arith.constant(ku, index=True)
k1 = lane_div_16
coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
vec_elems = kpack_bytes // int(b_elem_bytes)
b16 = _buffer_load_vec(
buffer_ops,
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def prefetch_x_to_lds(base_k, lds_buffer):
def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2))
idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds)
idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds)
Comment thread
sjfeng1999 marked this conversation as resolved.
loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16])
a_i64x2 = vector.bitcast(vec2_i64, loaded_a16)
a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[])
Expand Down Expand Up @@ -3074,10 +3074,10 @@ def load_x_tile(base_k):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = idx2crd(tx, layout_tx_wave_lane)
coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = layout_get(coord_wl, 0)
lane_id = layout_get(coord_wl, 1)
coord_l16 = idx2crd(lane_id, layout_lane16)
coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)

Expand Down Expand Up @@ -3330,7 +3330,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer):
def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2))
idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds)
idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds)
Comment thread
sjfeng1999 marked this conversation as resolved.
loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16])
a_i64x2 = vector.bitcast(vec2_i64, loaded_a16)
a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[])
Expand Down
Loading
Loading