Fine grained compute transform #970
inducer left a comment
Some thoughts from a first scroll through this code.
    )

    # FIXME: likely need to do some more digging to find proper domain to update
    new_domain = kernel.domains[0]
Use the DomainChanger.
    def compute(
        t_unit: TranslationUnit,
        substitution: str,
        *args,
Just repeat everything from below. basedpyright will help ensure consistency between both copies.
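A minimal sketch of what repeating the signature could look like, as opposed to forwarding `*args`. The names `compute_outer` and `compute_inner`, and the simplified dictionary-based kernel type, are hypothetical stand-ins rather than the PR's actual code. The point is only that a type checker such as basedpyright can compare two explicit signatures, whereas `*args` hides a mismatch until run time.

```python
from __future__ import annotations

from collections.abc import Sequence


def compute_inner(
    kernel: dict[str, object],
    substitution: str,
    storage_axes: Sequence[str],
    temporary_name: str | None = None,
) -> dict[str, object]:
    """Hypothetical per-kernel worker: records what it was asked to do."""
    result = dict(kernel)
    result["computed"] = (substitution, tuple(storage_axes), temporary_name)
    return result


def compute_outer(
    kernels: dict[str, dict[str, object]],
    substitution: str,
    storage_axes: Sequence[str],
    temporary_name: str | None = None,
) -> dict[str, dict[str, object]]:
    """Hypothetical entry point that repeats the worker's parameters.

    Because every parameter is spelled out, a renamed or retyped argument
    in compute_inner shows up as a type error at this call site.
    """
    return {
        name: compute_inner(kernel, substitution, storage_axes, temporary_name)
        for name, kernel in kernels.items()
    }
```

The duplication costs a few lines, but it keeps the public signature visible in documentation and editor tooltips.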
    from pytools.tag import Tag


    def compute(
It seems this could just be a @for_each_kernel decorator?
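For readers unfamiliar with the pattern, here is a simplified stand-in for a "for each kernel" decorator. It is not loopy's `for_each_kernel` implementation: `ToyKernel`, `ToyTranslationUnit`, and `toy_for_each_kernel` are hypothetical names used only to show how a per-kernel transformation can be lifted to a whole translation unit without a hand-written wrapper.

```python
from __future__ import annotations

import functools
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class ToyKernel:
    name: str
    tags: tuple[str, ...] = ()


@dataclass(frozen=True)
class ToyTranslationUnit:
    kernels: tuple[ToyKernel, ...]


def toy_for_each_kernel(transform):
    """Lift a kernel -> kernel function to translation unit -> translation unit."""

    @functools.wraps(transform)
    def wrapper(t_unit, *args, **kwargs):
        # Apply the per-kernel transform to every kernel, then rebuild the unit.
        new_kernels = tuple(
            transform(kernel, *args, **kwargs) for kernel in t_unit.kernels
        )
        return replace(t_unit, kernels=new_kernels)

    return wrapper


@toy_for_each_kernel
def compute(kernel: ToyKernel, substitution: str) -> ToyKernel:
    # Placeholder body: just mark the kernel as transformed.
    return replace(kernel, tags=kernel.tags + (f"computed:{substitution}",))
```

With the decorator in place, the separate outer function that loops over kernels becomes unnecessary.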
    ) -> LoopKernel:
        """
        Inserts an instruction to compute an expression given by :arg:`substitution`
        and replaces all invocations of :arg:`substitution` with the result of the
> all

Not really: only the relevant ones, where "relevant" should be defined below.
    return t_unit


    def _compute_inner(
IMO, this would've been an opportunity to try out namedisl, at least locally, to see how the interface "feels".
        substitution: str,
        transform_map: isl.Map,
        compute_map: isl.Map,
        storage_inames: list[str],
What is storage_inames?
Missed this the first time around. storage_inames corresponds to the inames that would be generated in something like a tiled matmul to fill shared memory with input tiles. Maybe storage_axes is a better name. This corresponds to the a in the (a, l) range of compute_map.
        substitution: str,
        transform_map: isl.Map,
        compute_map: isl.Map,
        storage_inames: list[str],
Prefer Sequence to list on input.
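A short illustration of the suggestion, assuming a hypothetical helper `normalize_storage_axes` that is not part of the PR. Annotating the input as `Sequence[str]` lets callers pass a list or a tuple, and converting to a tuple internally keeps the function from mutating the caller's data. One caveat worth guarding against: a bare `str` also satisfies `Sequence[str]`.

```python
from __future__ import annotations

from collections.abc import Sequence


def normalize_storage_axes(storage_axes: Sequence[str]) -> tuple[str, ...]:
    """Accept any sequence of names and return an immutable copy."""
    if isinstance(storage_axes, str):
        # A single string is itself a Sequence[str]; reject it explicitly
        # so "a0" is not silently treated as the axes "a" and "0".
        raise TypeError("expected a sequence of names, not a single string")
    return tuple(storage_axes)
```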
    # {{{ express substitution inputs as pw affs of (storage, time) names

    compute_pw_aff = compute_map.reverse().as_pw_multi_aff()
What does .as_pw_multi_aff() do in this context? I've never used it.
In this particular instance, it expresses substitution inputs as piecewise affine functions of (a, l). compute uses the output of the resulting PwMultiAff to determine the multidimensional index expressions of the RHS of a substitution rule.
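To make the reply concrete, here is a plain-Python sketch (no isl involved) of what "substitution inputs as functions of (a, l)" can mean for a one-dimensional tiling. The tile size, the particular map `i -> (a, l) = (i mod TILE, i div TILE)`, and the function names are all assumptions chosen for illustration. The PR obtains the analogous expression from `compute_map.reverse().as_pw_multi_aff()` rather than writing it by hand.

```python
TILE = 4  # assumed tile size, for illustration only


def compute_map(i: int) -> tuple[int, int]:
    """Forward direction: substitution input i -> (storage a, time l)."""
    return i % TILE, i // TILE


def reversed_compute_map(a: int, l: int) -> int:
    """Reverse direction: (a, l) -> i, an affine expression in a and l.

    This is the hand-written counterpart of the piecewise affine function
    that the reversed map yields: i = TILE * l + a on 0 <= a < TILE.
    """
    if not 0 <= a < TILE:
        raise ValueError("storage index outside the tile")
    return TILE * l + a
```

In this toy setting, the index expression that a compute instruction would use on the right-hand side is `TILE * l + a`, written in terms of the storage and time names only.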
    rule_mapping_ctx = SubstitutionRuleMappingContext(
        kernel.substitutions, kernel.get_var_name_generator())

    expr_subst_map = RuleAwareSubstitutionMapper(
You'll need to subclass this guy. Otherwise you won't be able to decide whether the usage site is "in-footprint".
My understanding was that RuleAwareSubstitutionMapper only mapped storage axes to index expressions in pymbolic.
Do you mean RuleInvocationReplacer? This explicitly checks footprints with ArrayToBufferMap and some other information computed earlier in precompute.
    new_insns = [compute_insn]

    # add global sync if we are storing in global memory
    if temporary_address_space == lp.AddressSpace.GLOBAL:
Why is global special-cased here? There's also sometimes a need to insert local barriers. And whether that's needed is a function of the dependency structure.
Snagged this from the existing precompute during development. It should have been taken out. I agree that we should rely on dependency checking for barrier insertion.
Adds a transformation to compute an expression given by a substitution rule using polyhedral maps and replace invocations of that substitution rule with the result of the computation.
Depends on #916
Replaces precompute: https://github.com/inducer/loopy/blob/main/loopy/transform/precompute.py#L387-L1170