Draft
196 changes: 196 additions & 0 deletions loopy/transform/compute.py
@@ -0,0 +1,196 @@
import islpy as isl

import loopy as lp
from loopy.kernel import LoopKernel
from loopy.kernel.data import AddressSpace
from loopy.kernel.function_interface import CallableKernel, ScalarCallable
from loopy.match import parse_stack_match
from loopy.symbolic import (
    RuleAwareSubstitutionMapper,
    SubstitutionRuleMappingContext,
    pw_aff_to_expr,
)
from loopy.translation_unit import TranslationUnit

from pymbolic import var
from pymbolic.mapper.substitutor import make_subst_func

from pytools.tag import Tag


def compute(
Owner

It seems this could just be a @for_each_kernel decorator?
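
A minimal sketch of the decorator pattern suggested here, written with stand-in classes rather than loopy itself. `Kernel`, `Unit`, `for_each_kernel_sketch`, and `add_compute_insn` are hypothetical names used only for illustration; loopy's actual `for_each_kernel` may differ in detail, for example in how it treats scalar callables.

```python
from dataclasses import dataclass, replace
from functools import wraps
from typing import Callable


# Stand-ins for loopy's kernel and translation unit (assumptions, not loopy types).
@dataclass(frozen=True)
class Kernel:
    name: str
    insns: tuple[str, ...] = ()


@dataclass(frozen=True)
class Unit:
    callables: dict[str, Kernel]


def for_each_kernel_sketch(transform: Callable[..., Kernel]) -> Callable[..., Unit]:
    """Lift a per-kernel transform to one that maps over a whole unit."""
    @wraps(transform)
    def wrapper(unit: Unit, *args, **kwargs) -> Unit:
        new_callables = {
            name: transform(knl, *args, **kwargs)
            for name, knl in unit.callables.items()
        }
        # Build a new unit from the transformed kernels.
        return replace(unit, callables=new_callables)
    return wrapper


@for_each_kernel_sketch
def add_compute_insn(kernel: Kernel, substitution: str) -> Kernel:
    # The per-kernel function only ever sees a single kernel.
    return replace(kernel, insns=kernel.insns + (f"{substitution}_compute",))


unit = Unit({"k1": Kernel("k1"), "k2": Kernel("k2", ("existing",))})
new_unit = add_compute_insn(unit, "f")
```

With this shape, the per-kernel function carries the full, typed signature, and the wrapper is the only place that iterates over callables.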

    t_unit: TranslationUnit,
    substitution: str,
    *args,

Check warning on line 24 in loopy/transform/compute.py

GitHub Actions / basedpyright

Type annotation is missing for parameter "args" (reportMissingParameterType)
Type of parameter "args" is unknown (reportUnknownParameterType)
Owner

Just repeat everything from below. basedpyright will help ensure consistency between both copies.

    **kwargs

Check warning on line 25 in loopy/transform/compute.py

GitHub Actions / basedpyright

Type annotation is missing for parameter "kwargs" (reportMissingParameterType)
Type of parameter "kwargs" is unknown (reportUnknownParameterType)
) -> TranslationUnit:
    """
    Entrypoint for performing a compute transformation on all kernels in a
    translation unit. See :func:`_compute_inner` for more details.
    """

    assert isinstance(t_unit, TranslationUnit)
    new_callables = {}

    for id, callable in t_unit.callables_table.items():
        if isinstance(callable, CallableKernel):
            kernel = _compute_inner(
                callable.subkernel,
                substitution,
                *args, **kwargs

Check warning on line 40 in loopy/transform/compute.py

GitHub Actions / basedpyright

Argument type is unknown
  Argument corresponds to parameter "temporary_address_space" in function "_compute_inner" (reportUnknownArgumentType)
Argument type is unknown
  Argument corresponds to parameter "default_tag" in function "_compute_inner" (reportUnknownArgumentType)
Argument type is unknown
  Argument corresponds to parameter "storage_inames" in function "_compute_inner" (reportUnknownArgumentType)
Argument type is unknown
  Argument corresponds to parameter "compute_map" in function "_compute_inner" (reportUnknownArgumentType)
Argument type is unknown
  Argument corresponds to parameter "transform_map" in function "_compute_inner" (reportUnknownArgumentType)
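            )

            callable = callable.copy(subkernel=kernel)
        elif isinstance(callable, ScalarCallable):
            pass
        else:
            raise NotImplementedError()

        new_callables[id] = callable

    # Build the result from the transformed callables rather than
    # returning the input translation unit unchanged.
    return t_unit.copy(callables_table=new_callables)


def _compute_inner(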
Owner

IMO, this would've been an opportunity to try out namedisl, at least locally, to see how the interface "feels".

    kernel: LoopKernel,
    substitution: str,
    transform_map: isl.Map,
    compute_map: isl.Map,
    storage_inames: list[str],
Owner

What is storage_inames?

Contributor Author

Missed this the first time around. `storage_inames` corresponds to the inames that would be generated in something like a tiled matmul to fill shared memory with input tiles. Maybe `storage_axes` is a better name. This corresponds to the `a` in the `(a, l)` range of `compute_map`.

Owner

Prefer Sequence to list on input.
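
A small sketch of what accepting a `Sequence` buys: callers can pass a list or a tuple, and the function returns a fresh list without touching the input. The helper name `normalize_storage_axes` and the `existing_inames` parameter are assumptions for illustration; the renaming rule mirrors the `"_store"` suffix used in this PR.

```python
from collections.abc import Sequence


def normalize_storage_axes(
    storage_inames: Sequence[str],
    existing_inames: frozenset[str],
) -> list[str]:
    """Rename storage axes that clash with inames already in the kernel."""
    return [
        name + "_store" if name in existing_inames else name
        for name in storage_inames
    ]


existing = frozenset({"i", "j"})

# Both a list and a tuple satisfy Sequence[str].
from_list = normalize_storage_axes(["i", "a"], existing)
from_tuple = normalize_storage_axes(("i", "a"), existing)
```

One caveat worth keeping in mind: a bare `str` also satisfies `Sequence[str]`, so a caller passing `"ij"` would be read as two axes, `"i"` and `"j"`.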

    default_tag: Tag | str | None = None,
    temporary_address_space: AddressSpace | None = None
) -> LoopKernel:
    """
    Inserts an instruction to compute an expression given by :arg:`substitution`
    and replaces all invocations of :arg:`substitution` with the result of the
Owner

On "all": not really. Only the relevant invocations are replaced, where "relevant" should be defined below.

    compute instruction.

    :arg substitution: The substitution rule for which the compute
        transform should be applied.

    :arg transform_map: An :class:`isl.Map` representing the affine
        transformation from the original iname domain to the transformed iname
        domain.

    :arg compute_map: An :class:`isl.Map` representing a relation between
        substitution rule indices and tuples `(a, l)`, where `a` is a vector of
        storage indices and `l` is a vector of "timestamps".
Owner

How is the boundary of a and l determined?

Contributor Author

In its current form, it relies on user input to determine what `a` is (this is `storage_inames`).
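
To make the split concrete, here is a sketch in plain Python of partitioning the output dimension names of `compute_map` into storage axes `a` and timestamps `l` by membership in `storage_inames`. The helper `split_range_dims` and the dimension names are hypothetical; the real code works on an `isl.Map`, not on lists of names.

```python
from collections.abc import Sequence


def split_range_dims(
    out_dim_names: Sequence[str],
    storage_inames: Sequence[str],
) -> tuple[list[str], list[str]]:
    """Split output dims into (storage axes, time axes), preserving order."""
    storage_set = set(storage_inames)

    # Every requested storage iname must actually be an output dimension.
    missing = storage_set - set(out_dim_names)
    if missing:
        raise ValueError(f"not in the range of compute_map: {sorted(missing)}")

    storage_dims = [n for n in out_dim_names if n in storage_set]
    time_dims = [n for n in out_dim_names if n not in storage_set]
    return storage_dims, time_dims


# Hypothetical tiled access: inner tile indices are stored, tile numbers are time.
a_dims, l_dims = split_range_dims(["ii", "jj", "io", "jo"], ["ii", "jj"])
```

Under this convention anything the user does not name as a storage axis is treated as a timestamp, which is why the docstring should say so explicitly.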

    """

    # Test for None explicitly so that a falsy enum member is not overridden.
    if temporary_address_space is None:
        temporary_address_space = AddressSpace.GLOBAL

    # {{{ normalize names

    iname_to_storage_map = {
        iname: (iname + "_store" if iname in kernel.all_inames() else iname)
        for iname in storage_inames
    }

    new_storage_axes = list(iname_to_storage_map.values())

    for dim in range(compute_map.dim(isl.dim_type.out)):
        for iname, storage_ax in iname_to_storage_map.items():
            if compute_map.get_dim_name(isl.dim_type.out, dim) == iname:
                compute_map = compute_map.set_dim_name(
                    isl.dim_type.out, dim, storage_ax
                )

    # }}}

    # {{{ update kernel domain to contain storage inames

    storage_domain = compute_map.range().project_out_except(
        new_storage_axes, [isl.dim_type.set]
    )

    # FIXME: likely need to do some more digging to find proper domain to update
    new_domain = kernel.domains[0]
Owner

Use the DomainChanger.

    for ax in new_storage_axes:
        new_domain = new_domain.add_dims(isl.dim_type.set, 1)

        new_domain = new_domain.set_dim_name(
            isl.dim_type.set,
            new_domain.dim(isl.dim_type.set) - 1,
            ax
        )

    new_domain, storage_domain = isl.align_two(new_domain, storage_domain)
    new_domain = new_domain & storage_domain
    kernel = kernel.copy(domains=[new_domain])

    # }}}

    # {{{ express substitution inputs as pw affs of (storage, time) names

    compute_pw_aff = compute_map.reverse().as_pw_multi_aff()
Owner

What does `.as_pw_multi_aff()` do in this context? I've never used it.

Contributor Author (a-alveyblanc, Nov 21, 2025)

In this particular instance, it expresses the substitution inputs as piecewise affine functions of `(a, l)`. `compute` uses the output of the resulting `PwMultiAff` to determine the multidimensional index expressions on the RHS of a substitution rule.
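
A worked illustration of that idea in plain Python, not islpy: the relation is enumerated explicitly, reversed, and then compared against the affine expression one would expect the `PwMultiAff` to encode. The tile size, the extent, and the specific map `i -> (a, l)` with `i = TILE*l + a` are assumptions chosen for the example.

```python
TILE = 4   # hypothetical tile size
N = 16     # hypothetical extent of the substitution index i

# compute_map as an explicit relation: i -> (a, l) with i = TILE*l + a, 0 <= a < TILE.
compute_map = {(i,): (i % TILE, i // TILE) for i in range(N)}

# Reverse the relation: (a, l) -> i. The reverse must be single-valued
# for it to be expressible as a function of (a, l).
reverse = {}
for (i,), al in compute_map.items():
    assert al not in reverse, "reverse relation is not single-valued"
    reverse[al] = (i,)


def subst_index(a: int, l: int) -> tuple[int]:
    """The affine expression for the substitution input in terms of (a, l)."""
    return (TILE * l + a,)


# The closed-form expression agrees with the enumerated reverse relation.
matches = all(subst_index(a, l) == idx for (a, l), idx in reverse.items())
```

In the transform, the analogue of `subst_index` is what gets substituted for the rule's arguments, so the compute instruction evaluates the rule at index `TILE*l + a` and stores the result at storage index `a`.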

    subst_inp_names = [
        compute_map.get_dim_name(isl.dim_type.in_, i)
        for i in range(compute_map.dim(isl.dim_type.in_))
    ]
    storage_ax_to_global_expr = dict.fromkeys(subst_inp_names)
    for dim in range(compute_pw_aff.dim(isl.dim_type.out)):
        subst_inp = compute_map.get_dim_name(isl.dim_type.in_, dim)
        storage_ax_to_global_expr[subst_inp] = \
            pw_aff_to_expr(compute_pw_aff.get_at(dim))

    # }}}

    # {{{ generate instruction from compute map

    rule_mapping_ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())

    expr_subst_map = RuleAwareSubstitutionMapper(
Owner

You'll need to subclass this guy. Otherwise you won't be able to decide whether the usage site is "in-footprint".

Contributor Author

My understanding was that RuleAwareSubstitutionMapper only mapped storage axes to index expressions in pymbolic.

Do you mean RuleInvocationReplacer? This explicitly checks footprints with ArrayToBufferMap and some other information computed earlier in precompute.
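
One possible reading of "in-footprint", sketched in plain Python with explicit index sets instead of isl sets: a usage site may be rewritten to read the temporary only if every substitution index it touches is among the indices the compute instruction stores. The helpers `footprint_from_compute_map` and `is_in_footprint`, and the tile bounds, are hypothetical.

```python
TILE = 4  # hypothetical tile size


def footprint_from_compute_map(compute_map: dict) -> set:
    """Substitution indices that the compute instruction stores."""
    return set(compute_map.keys())


def is_in_footprint(usage_indices, stored: set) -> bool:
    """True if every index read at the usage site has been stored."""
    return set(usage_indices) <= stored


# The compute instruction covers substitution indices 4..7 (one tile).
compute_map = {(i,): (i % TILE, i // TILE) for i in range(4, 8)}
stored = footprint_from_compute_map(compute_map)

# A usage site reading f(i) over the same tile stays inside the footprint.
aligned_use = [(i,) for i in range(4, 8)]
# A usage site reading f(i + 1) reaches index 8, which was never stored.
shifted_use = [(i + 1,) for i in range(4, 8)]

replace_aligned = is_in_footprint(aligned_use, stored)
replace_shifted = is_in_footprint(shifted_use, stored)
```

Only the aligned usage would be replaced by a read of the temporary; the shifted one would have to keep invoking the substitution rule.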

        rule_mapping_ctx,
        make_subst_func(storage_ax_to_global_expr),

Check failure on line 145 in loopy/transform/compute.py

GitHub Actions / basedpyright

Argument of type "dict[str | None, Any | None]" cannot be assigned to parameter "variable_assignments" of type "CanGetitem[Any, Expression]" in function "make_subst_func"
  "dict[str | None, Any | None]" is incompatible with protocol "CanGetitem[Any, Expression]"
    "__getitem__" is an incompatible type
      Type "(key: str | None, /) -> (Any | None)" is not assignable to type "(key: _K_contra@CanGetitem, /) -> _V_co@CanGetitem"
        Function return type "Any | None" is incompatible with type "_V_co@CanGetitem"
          Type "Any | None" is not assignable to type "Expression" (reportArgumentType)
        within=parse_stack_match(None)
    )

    subst_expr = kernel.substitutions[substitution].expression
    compute_expression = expr_subst_map(subst_expr, kernel, None)

Check failure on line 150 in loopy/transform/compute.py

GitHub Actions / basedpyright

Argument of type "None" cannot be assigned to parameter "insn" of type "InstructionBase" in function "__call__"
  "None" is not assignable to "InstructionBase" (reportArgumentType)

    temporary_name = substitution + "_temp"
    assignee = var(temporary_name)[tuple(
        var(iname) for iname in new_storage_axes
    )]

    compute_insn_id = substitution + "_compute"
    compute_insn = lp.Assignment(
        id=compute_insn_id,
        assignee=assignee,
        expression=compute_expression,
    )

    compute_dep_id = compute_insn_id

Check warning on line 164 in loopy/transform/compute.py

GitHub Actions / basedpyright

Variable "compute_dep_id" is not accessed (reportUnusedVariable)
    new_insns = [compute_insn]

    # add global sync if we are storing in global memory
    if temporary_address_space == lp.AddressSpace.GLOBAL:
Owner

Why is global special-cased here? There's also sometimes a need to insert local barriers. And whether that's needed is a function of the dependency structure.

Contributor Author

Snagged this from existing precompute during development. This should have been taken out. I agree that we should rely on dependency checking for barrier insertion.

        gbarrier_id = kernel.make_unique_instruction_id(
            based_on=substitution + "_barrier"
        )

        from loopy.kernel.instruction import BarrierInstruction
        barrier_insn = BarrierInstruction(
            id=gbarrier_id,
            depends_on=frozenset([compute_insn_id]),
            synchronization_kind="global",
            mem_kind="global"
        )

        compute_dep_id = gbarrier_id

    # }}}

    # {{{ replace substitution rule with newly created instruction

    # FIXME: get these properly (see `precompute`)
    subst_name = substitution
    subst_tag = None
    within = None  # do we need this?

    # }}}

    return kernel