@@ -2115,4 +2115,64 @@ def get_outer_params(domains):
21152115# }}}
21162116
21172117
2118+ # {{{ get access map from an instruction
2119+
2120+ class _IndexCollector (CombineMapper ):
2121+ def __init__ (self , var ):
2122+ self .var = var
2123+ super ().__init__ ()
2124+
2125+ def combine (self , values ):
2126+ import operator
2127+ return reduce (operator .or_ , values , frozenset ())
2128+
2129+ def map_subscript (self , expr ):
2130+ if expr .aggregate .name == self .var :
2131+ return (super ().map_subscript (expr ) | frozenset ([expr .index_tuple ]))
2132+ else :
2133+ return super ().map_subscript (expr )
2134+
2135+ def map_algebraic_leaf (self , expr ):
2136+ return frozenset ()
2137+
2138+ map_constant = map_algebraic_leaf
2139+
2140+
2141+ def _project_out_inames_from_maps (amaps , inames_to_project_out ):
2142+ new_amaps = []
2143+ for amap in amaps :
2144+ for iname in inames_to_project_out :
2145+ dt , pos = amap .get_var_dict ()[iname ]
2146+ amap = amap .project_out (dt , pos , 1 )
2147+
2148+ new_amaps .append (amap )
2149+
2150+ return new_amaps
2151+
2152+
2153+ def _union_amaps (amaps ):
2154+ import islpy as isl
2155+ return reduce (isl .Map .union , amaps [1 :], amaps [0 ])
2156+
2157+
2158+ def get_insn_access_map (kernel , insn_id , var ):
2159+ from loopy .transform .subst import expand_subst
2160+ from loopy .match import Id
2161+ from loopy .symbolic import get_access_map
2162+
2163+ insn = kernel .id_to_insn [insn_id ]
2164+
2165+ kernel = expand_subst (kernel , within = Id (insn_id ))
2166+ indices = list (_IndexCollector (var )((insn .expression ,
2167+ insn .assignees ,
2168+ tuple (insn .predicates ))))
2169+
2170+ amaps = [get_access_map (kernel .get_inames_domain (insn .within_inames ),
2171+ idx , kernel .assumptions )
2172+ for idx in indices ]
2173+
2174+ return _union_amaps (amaps )
2175+
2176+ # }}}
2177+
21182178# vim: foldmethod=marker
0 commit comments