Skip to content

Commit 23be763

Browse files
adds a utility to get the instruction access map
Co-authored-by: Matthias Diener <[email protected]>
1 parent c21c092 commit 23be763

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

loopy/kernel/tools.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2115,4 +2115,64 @@ def get_outer_params(domains):
21152115
# }}}
21162116

21172117

2118+
# {{{ get access map from an instruction
2119+
2120+
class _IndexCollector(CombineMapper):
2121+
def __init__(self, var):
2122+
self.var = var
2123+
super().__init__()
2124+
2125+
def combine(self, values):
2126+
import operator
2127+
return reduce(operator.or_, values, frozenset())
2128+
2129+
def map_subscript(self, expr):
2130+
if expr.aggregate.name == self.var:
2131+
return (super().map_subscript(expr) | frozenset([expr.index_tuple]))
2132+
else:
2133+
return super().map_subscript(expr)
2134+
2135+
def map_algebraic_leaf(self, expr):
2136+
return frozenset()
2137+
2138+
map_constant = map_algebraic_leaf
2139+
2140+
2141+
def _project_out_inames_from_maps(amaps, inames_to_project_out):
2142+
new_amaps = []
2143+
for amap in amaps:
2144+
for iname in inames_to_project_out:
2145+
dt, pos = amap.get_var_dict()[iname]
2146+
amap = amap.project_out(dt, pos, 1)
2147+
2148+
new_amaps.append(amap)
2149+
2150+
return new_amaps
2151+
2152+
2153+
def _union_amaps(amaps):
2154+
import islpy as isl
2155+
return reduce(isl.Map.union, amaps[1:], amaps[0])
2156+
2157+
2158+
def get_insn_access_map(kernel, insn_id, var):
2159+
from loopy.transform.subst import expand_subst
2160+
from loopy.match import Id
2161+
from loopy.symbolic import get_access_map
2162+
2163+
insn = kernel.id_to_insn[insn_id]
2164+
2165+
kernel = expand_subst(kernel, within=Id(insn_id))
2166+
indices = list(_IndexCollector(var)((insn.expression,
2167+
insn.assignees,
2168+
tuple(insn.predicates))))
2169+
2170+
amaps = [get_access_map(kernel.get_inames_domain(insn.within_inames),
2171+
idx, kernel.assumptions)
2172+
for idx in indices]
2173+
2174+
return _union_amaps(amaps)
2175+
2176+
# }}}
2177+
21182178
# vim: foldmethod=marker

0 commit comments

Comments
 (0)