@@ -88,6 +88,13 @@ struct LoopSet
8888 # ref_to_sym_aliases::Dict{ArrayReference,Symbol}
8989end
9090
91+ # function op_to_ref(ls::LoopSet, op::Operation)
92+ # s = op.variable
93+ # id = findfirst(ls.syms_aliasing_regs)
94+ # @assert id !== nothing
95+ # ls.refs_aliasing_syms[id]
96+ # end
97+
9198function includesarray (ls:: LoopSet , array:: Symbol )
9299 for (a,i) ∈ ls. includedarrays
93100 a === array && return i
@@ -235,7 +242,7 @@ function add_load!(
235242 :getindex , memload, loopdependencies (ref),
236243 NODEPENDENCY, NOPARENTS, ref
237244 )
238- add_vptr! (ls, indexed , identifier (op))
245+ add_vptr! (ls, ref . array , identifier (op))
239246 pushop! (ls, op, var)
240247end
241248function add_load_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
@@ -311,12 +318,16 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
311318 else
312319 return add_operation! (ls, gensym (:temporary ), expr, elementbytes)
313320 end
314- ref = ArrayReference ( ex. args[1 + offset], @view (ex. args[2 + offset: end ]) ):: ArrayReference
321+ ref = ArrayReference (
322+ expr. args[1 + offset],
323+ @view (expr. args[2 + offset: end ]),
324+ Ref (false )
325+ ):: ArrayReference
315326 id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
316327 if id === nothing
317- add_load! ( ls, gensym (:temporary ), array, args , elementbytes )
328+ add_load! ( ls, gensym (:temporary ), ref , elementbytes )
318329 else
319- ls . syms_aliasing_refs[id]
330+ getop (ls, ls . syms_aliasing_refs[id])
320331 end
321332 # id = includesarray(ls, array)
322333 # if id > 0
@@ -371,7 +382,7 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
371382 # op = Operation( length(operations(ls)), var, elementbytes, instr, compute )
372383 reduction = false
373384 for arg ∈ args
374- if arg === var
385+ if var === arg
375386 reduction = true
376387 add_reduction! (parents, deps, reduceddeps, ls, arg, elementbytes)
377388 elseif ref == arg
@@ -389,18 +400,20 @@ function add_compute!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8,
389400 end
390401end
391402function add_store! (
392- ls:: LoopSet , indexed :: Symbol , var:: Symbol , indices :: AbstractVector , elementbytes:: Int = 8
403+ ls:: LoopSet , var:: Symbol , ref :: ArrayReference , elementbytes:: Int = 8
393404)
394405 parent = getop (ls, var)
395- op = Operation ( length (operations (ls)), indexed , elementbytes, :setindex! , memstore, indices , reduceddependencies (parent), [parent] )
396- add_vptr! (ls, indexed , identifier (op))
406+ op = Operation ( length (operations (ls)), ref . array , elementbytes, :setindex! , memstore, loopdependencies (ref) , reduceddependencies (parent), [parent], ref )
407+ add_vptr! (ls, ref . array , identifier (op))
397408 pushop! (ls, op, var)
398409end
399410function add_store_ref! (ls:: LoopSet , var:: Symbol , ex:: Expr , elementbytes:: Int = 8 )
400- add_store! (ls, ex. args[1 ], var, @view (ex. args[2 : end ]), elementbytes)
411+ ref = ref_from_ref (ex)
412+ add_store! (ls, var, ref, elementbytes)
401413end
402414function add_store_setindex! (ls:: LoopSet , ex:: Expr , elementbytes:: Int = 8 )
403- add_store! (ls, ex. args[2 ], ex. args[3 ], @view (ex. args[4 : end ]), elementbytes)
415+ ref = ref_from_setindex (ex)
416+ add_store! (ls, var, ref, elementbytes)
404417end
405418# add operation assigns X to var
406419function add_operation! (
@@ -418,6 +431,21 @@ function add_operation!(
418431 throw (" Expression not recognized:\n $x " )
419432 end
420433end
434+ function add_operation! (
435+ ls:: LoopSet , LHS_sym:: Symbol , RHS:: Expr , LHS_ref:: ArrayReference , elementbytes:: Int = 8
436+ )
437+ if RHS. head === :ref # || (RHS.head === :call && first(RHS.args) === :getindex)
438+ add_load! (ls, LHS_sym, LHS_ref, elementbytes)
439+ elseif RHS. head === :call
440+ if first (RHS. args) === :getindex
441+ add_load! (ls, LHS_sym, LHS_ref, elementbytes)
442+ else
443+ add_compute! (ls, LHS_sym, RHS, elementbytes, LHS_ref)
444+ end
445+ else
446+ throw (" Expression not recognized:\n $x " )
447+ end
448+ end
421449function Base. push! (ls:: LoopSet , ex:: Expr , elementbytes:: Int = 8 )
422450 if ex. head === :call
423451 finex = first (ex. args):: Symbol
@@ -446,7 +474,9 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
446474 ref = ArrayReference (LHS)
447475 id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
448476 lrhs = id === nothing ? gensym (:RHS ) : ls. syms_aliasing_refs[id]
449- add_operation! (ls, lrhs, RHS, elementbytes, ref)
477+ # we pass ref, so it can compare references within RHS, and realize
478+ # they equal lrhs
479+ add_operation! (ls, lrhs, RHS, ref, elementbytes)
450480 end
451481 add_store_ref! (ls, lrhs, LHS, elementbytes)
452482 else
0 commit comments