-
Notifications
You must be signed in to change notification settings - Fork 78
Fine grained compute transform #970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,196 @@ | ||
| import islpy as isl | ||
|
|
||
| import loopy as lp | ||
| from loopy.kernel import LoopKernel | ||
| from loopy.kernel.data import AddressSpace | ||
| from loopy.kernel.function_interface import CallableKernel, ScalarCallable | ||
| from loopy.match import parse_stack_match | ||
| from loopy.symbolic import ( | ||
| RuleAwareSubstitutionMapper, | ||
| SubstitutionRuleMappingContext, | ||
| pw_aff_to_expr | ||
| ) | ||
| from loopy.translation_unit import TranslationUnit | ||
|
|
||
| from pymbolic import var | ||
| from pymbolic.mapper.substitutor import make_subst_func | ||
|
|
||
| from pytools.tag import Tag | ||
|
|
||
|
|
||
| def compute( | ||
| t_unit: TranslationUnit, | ||
| substitution: str, | ||
| *args, | ||
|
Check warning on line 24 in loopy/transform/compute.py
|
||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just repeat everything from below. basedpyright will help ensure consistency between both copies. |
||
| **kwargs | ||
|
Check warning on line 25 in loopy/transform/compute.py
|
||
| ) -> TranslationUnit: | ||
| """ | ||
| Entrypoint for performing a compute transformation on all kernels in a | ||
| translation unit. See :func:`_compute_inner` for more details. | ||
| """ | ||
|
|
||
| assert isinstance(t_unit, TranslationUnit) | ||
| new_callables = {} | ||
|
|
||
| for id, callable in t_unit.callables_table.items(): | ||
| if isinstance(callable, CallableKernel): | ||
| kernel = _compute_inner( | ||
| callable.subkernel, | ||
| substitution, | ||
| *args, **kwargs | ||
|
Check warning on line 40 in loopy/transform/compute.py
|
||
| ) | ||
|
|
||
| callable = callable.copy(subkernel=kernel) | ||
| elif isinstance(callable, ScalarCallable): | ||
| pass | ||
| else: | ||
| raise NotImplementedError() | ||
|
|
||
| new_callables[id] = callable | ||
|
|
||
| return t_unit | ||
|
|
||
| def _compute_inner( | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO, this would've been an opportunity to try out namedisl, at least locally, to see how the interface "feels". |
||
| kernel: LoopKernel, | ||
| substitution: str, | ||
| transform_map: isl.Map, | ||
| compute_map: isl.Map, | ||
| storage_inames: list[str], | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missed this the first time around.
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer |
||
| default_tag: Tag | str | None = None, | ||
| temporary_address_space: AddressSpace | None = None | ||
| ) -> LoopKernel: | ||
| """ | ||
| Inserts an instruction to compute an expression given by :arg:`substitution` | ||
| and replaces all invocations of :arg:`substitution` with the result of the | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Not really: only the relevant ones, where relevant should be defined below. |
||
| compute instruction. | ||
|
|
||
| :arg substitution: The substitution rule for which the compute | ||
| transform should be applied. | ||
|
|
||
| :arg transform_map: An :class:`isl.Map` representing the affine | ||
| transformation from the original iname domain to the transformed iname | ||
| domain. | ||
|
|
||
| :arg compute_map: An :class:`isl.Map` representing a relation between | ||
| substitution rule indices and tuples `(a, l)`, where `a` is a vector of | ||
| storage indices and `l` is a vector of "timestamps". | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How is the boundary of
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In it's current form, it relies on user input to determine what |
||
| """ | ||
|
|
||
| if not temporary_address_space: | ||
| temporary_address_space = AddressSpace.GLOBAL | ||
|
|
||
| # {{{ normalize names | ||
|
|
||
| iname_to_storage_map = { | ||
| iname : (iname + "_store" if iname in kernel.all_inames() else iname) | ||
| for iname in storage_inames | ||
| } | ||
|
|
||
| new_storage_axes = list(iname_to_storage_map.values()) | ||
|
|
||
| for dim in range(compute_map.dim(isl.dim_type.out)): | ||
| for iname, storage_ax in iname_to_storage_map.items(): | ||
| if compute_map.get_dim_name(isl.dim_type.out, dim) == iname: | ||
| compute_map = compute_map.set_dim_name( | ||
| isl.dim_type.out, dim, storage_ax | ||
| ) | ||
|
|
||
| # }}} | ||
|
|
||
| # {{{ update kernel domain to contain storage inames | ||
|
|
||
| storage_domain = compute_map.range().project_out_except( | ||
| new_storage_axes, [isl.dim_type.set] | ||
| ) | ||
|
|
||
| # FIXME: likely need to do some more digging to find proper domain to update | ||
| new_domain = kernel.domains[0] | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use the |
||
| for ax in new_storage_axes: | ||
| new_domain = new_domain.add_dims(isl.dim_type.set, 1) | ||
|
|
||
| new_domain = new_domain.set_dim_name( | ||
| isl.dim_type.set, | ||
| new_domain.dim(isl.dim_type.set) - 1, | ||
| ax | ||
| ) | ||
|
|
||
| new_domain, storage_domain = isl.align_two(new_domain, storage_domain) | ||
| new_domain = new_domain & storage_domain | ||
| kernel = kernel.copy(domains=[new_domain]) | ||
|
|
||
| # }}} | ||
|
|
||
| # {{{ express substitution inputs as pw affs of (storage, time) names | ||
|
|
||
| compute_pw_aff = compute_map.reverse().as_pw_multi_aff() | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What does
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this particular instance, it expresses substitution inputs as piecewise affine functions of |
||
| subst_inp_names = [ | ||
| compute_map.get_dim_name(isl.dim_type.in_, i) | ||
| for i in range(compute_map.dim(isl.dim_type.in_)) | ||
| ] | ||
| storage_ax_to_global_expr = dict.fromkeys(subst_inp_names) | ||
| for dim in range(compute_pw_aff.dim(isl.dim_type.out)): | ||
| subst_inp = compute_map.get_dim_name(isl.dim_type.in_, dim) | ||
| storage_ax_to_global_expr[subst_inp] = \ | ||
| pw_aff_to_expr(compute_pw_aff.get_at(dim)) | ||
|
|
||
| # }}} | ||
|
|
||
| # {{{ generate instruction from compute map | ||
|
|
||
| rule_mapping_ctx = SubstitutionRuleMappingContext( | ||
| kernel.substitutions, kernel.get_var_name_generator()) | ||
|
|
||
| expr_subst_map = RuleAwareSubstitutionMapper( | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You'll need to subclass this guy. Otherwise you won't be able to decide whether the usage site is "in-footprint".
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My understanding was that Do you mean |
||
| rule_mapping_ctx, | ||
| make_subst_func(storage_ax_to_global_expr), | ||
|
Check failure on line 145 in loopy/transform/compute.py
|
||
| within=parse_stack_match(None) | ||
| ) | ||
|
|
||
| subst_expr = kernel.substitutions[substitution].expression | ||
| compute_expression = expr_subst_map(subst_expr, kernel, None) | ||
|
Check failure on line 150 in loopy/transform/compute.py
|
||
|
|
||
| temporary_name = substitution + "_temp" | ||
| assignee = var(temporary_name)[tuple( | ||
| var(iname) for iname in new_storage_axes | ||
| )] | ||
|
|
||
| compute_insn_id = substitution + "_compute" | ||
| compute_insn = lp.Assignment( | ||
| id=compute_insn_id, | ||
| assignee=assignee, | ||
| expression=compute_expression, | ||
| ) | ||
|
|
||
| compute_dep_id = compute_insn_id | ||
| new_insns = [compute_insn] | ||
|
|
||
| # add global sync if we are storing in global memory | ||
| if temporary_address_space == lp.AddressSpace.GLOBAL: | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is global special-cased here? There's also sometimes a need to insert local barriers. And whether that's needed is a function of the dependency structure.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Snagged this from existing |
||
| gbarrier_id = kernel.make_unique_instruction_id( | ||
| based_on=substitution + "_barrier" | ||
| ) | ||
|
|
||
| from loopy.kernel.instruction import BarrierInstruction | ||
| barrier_insn = BarrierInstruction( | ||
| id=gbarrier_id, | ||
| depends_on=frozenset([compute_insn_id]), | ||
| synchronization_kind="global", | ||
| mem_kind="global" | ||
| ) | ||
|
|
||
| compute_dep_id = gbarrier_id | ||
|
|
||
| # }}} | ||
|
|
||
| # {{{ replace substitution rule with newly created instruction | ||
|
|
||
| # FIXME: get these properly (see `precompute`) | ||
| subst_name = substitution | ||
| subst_tag = None | ||
| within = None # do we need this? | ||
|
|
||
|
|
||
|
|
||
| # }}} | ||
|
|
||
| return kernel | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems this could just be a
@for_each_kerneldecorator?