@@ -158,7 +158,22 @@ isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
"""
    looprangehint(ls::LoopSet, s::Symbol)

Return the `rangehint` field of the loop stored in `ls.loops` under the symbol `s`.
"""
function looprangehint(ls::LoopSet, s::Symbol)
    loop = ls.loops[s]
    return loop.rangehint
end
"""
    looprangesym(ls::LoopSet, s::Symbol)

Return the `rangesym` field of the loop stored in `ls.loops` under the symbol `s`.
"""
function looprangesym(ls::LoopSet, s::Symbol)
    loop = ls.loops[s]
    return loop.rangesym
end
160160# itersyms(ls::LoopSet) = keys(ls.loops)
"""
    getop(ls::LoopSet, var::Symbol, elementbytes::Int = 8)

Return the entry of `ls.opdict` keyed by `var`.

If `var` has no entry yet, `get!` creates one: an operation is built with
`add_constant!(ls, var, elementbytes)`, the assignment `mangledvar(op) = var` is
passed to `pushpreamble!`, and that operation becomes the value for `var`.
"""
function getop(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
    # Only invoked by `get!` when `var` is missing from `ls.opdict`.
    function newconstant()
        constop = add_constant!(ls, var, elementbytes)
        pushpreamble!(ls, Expr(:(=), mangledvar(constop), var))
        return constop
    end
    return get!(newconstant, ls.opdict, var)
end
"""
    getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int = 8)

Return the entry of `ls.opdict` keyed by `var`.

If `var` has no entry yet, `get!` creates one: an operation is built with
`add_constant!(ls, var, deps, gensym(:constant), elementbytes)`, the assignment
`mangledvar(op) = var` is passed to `pushpreamble!`, and that operation becomes
the value for `var`.
"""
function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int = 8)
    # Only invoked by `get!` when `var` is missing from `ls.opdict`.
    function newconstant()
        constop = add_constant!(ls, var, deps, gensym(:constant), elementbytes)
        pushpreamble!(ls, Expr(:(=), mangledvar(constop), var))
        return constop
    end
    return get!(newconstant, ls.opdict, var)
end
"""
    getop(ls::LoopSet, i::Int)

Return `ls.operations[i + 1]`, i.e. `i` is used as a zero-based offset into the
operations vector.
"""
function getop(ls::LoopSet, i::Int)
    position = i + 1
    return ls.operations[position]
end
163178
"""
    extract_val(::Val{N}) where {N}

Return the value `N` carried in the type parameter of a `Val{N}` instance.
"""
@inline function extract_val(::Val{N}) where {N}
    return N
end
@@ -284,7 +299,7 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
284299 if body. head === :block
285300 add_block! (ls, body, elementbytes)
286301 else
287- Base . push! (ls, q, elementbytes)
302+ push! (ls, q, elementbytes)
288303 end
289304end
290305function add_loop! (ls:: LoopSet , loop:: Loop )
@@ -316,12 +331,18 @@ function add_load!(
316331 ls:: LoopSet , var:: Symbol , ref:: ArrayReference , elementbytes:: Int = 8
317332)
318333 if ref. loaded[] == true
319- op = getop (ls, var)
334+ op = getop (ls, var, elementbytes )
320335 @assert var === op. variable
321336 return op
322337 end
323- push! (ls. syms_aliasing_refs, var)
324- push! (ls. refs_aliasing_syms, ref)
338+ id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
339+ if id === nothing
340+ push! (ls. syms_aliasing_refs, var)
341+ push! (ls. refs_aliasing_syms, ref)
342+ else
343+ opp = getop (ls, ls. syms_aliasing_refs[id], elementbytes)
344+ return isstore (opp) ? getop (ls, first (parents (opp))) : opp
345+ end
325346 ref. loaded[] = true
326347 # ls.sym_to_ref_aliases[ var ] = ref
327348 # ls.ref_to_sym_aliases[ ref ] = var
@@ -427,7 +448,7 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
427448 if id === nothing
428449 add_load! ( ls, gensym (:temporary ), ref, elementbytes )
429450 else
430- getop (ls, ls. syms_aliasing_refs[id])
451+ getop (ls, ls. syms_aliasing_refs[id], elementbytes)
431452 end
432453 # id = includesarray(ls, array)
433454 # if id > 0
@@ -440,12 +461,7 @@ function add_parent!(
440461 parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet , var, elementbytes:: Int = 8
441462)
442463 parent = if var isa Symbol
443- get! (ls. opdict, var) do
444- # might add constant
445- op = add_constant! (ls, var, elementbytes)
446- pushpreamble! (ls, Expr (:(= ), mangledvar (op), var))
447- op
448- end
464+ getop (ls, var, elementbytes)
449465 elseif var isa Expr # CSE candidate
450466 maybe_cse_load! (ls, var, elementbytes)
451467 else # assumed constant
@@ -465,7 +481,7 @@ function add_reduction_update_parent!(
465481 parents:: Vector{Operation} , deps:: Vector{Symbol} , reduceddeps:: Vector{Symbol} , ls:: LoopSet ,
466482 var:: Symbol , instr:: Symbol , elementbytes:: Int = 8
467483)
468- parent = getop (ls, var)
484+ parent = getop (ls, var, elementbytes )
469485 setdiffv! (reduceddeps, deps, loopdependencies (parent))
470486 pushparent! (parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
471487 op = Operation (length (operations (ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
@@ -502,23 +518,33 @@ end
"""
    add_store!(ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8)

Build an `Operation` (instruction `:setindex!`, constructed with `memstore`) that
stores the operation associated with `var` through `ref`, and hand it to `pushop!`.

The stored operation is looked up with `getop(ls, var, loopdependencies(ref),
elementbytes)`. Its variable is recorded in `ls.syms_aliasing_refs` together with
`ref` in `ls.refs_aliasing_syms`, unless that variable is already listed. Returns
whatever `pushop!` returns.
"""
function add_store!(
    ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
)
    storedeps = loopdependencies(ref)
    src = getop(ls, var, storedeps, elementbytes)
    srcsym = src.variable
    # Register the symbol/reference pair only once per symbol.
    if !(srcsym in ls.syms_aliasing_refs)
        push!(ls.syms_aliasing_refs, srcsym)
        push!(ls.refs_aliasing_syms, ref)
    end
    store = Operation(
        length(operations(ls)), ref.array, elementbytes, :setindex!, memstore,
        storedeps, reduceddependencies(src), [src], ref
    )
    add_vptr!(ls, ref.array, identifier(store), ref.ptr)
    return pushop!(ls, store, ref.array)
end
"""
    add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)

Convert `ex` to an `ArrayReference` with `ref_from_ref` and forward to
`add_store!(ls, var, ref, elementbytes)`, returning its result.
"""
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
    # The assertion fails early if `ref_from_ref` yields anything but an `ArrayReference`.
    return add_store!(ls, var, ref_from_ref(ex)::ArrayReference, elementbytes)
end
"""
    add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)

Add a store for a `setindex!(A, v, inds...)` call expression `ex`.

The array reference is obtained from `ref_from_setindex(ex)`; the stored variable
is the value argument `v` of the call, which must be a `Symbol`. Returns the
result of `add_store!`.
"""
function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
    ref = ref_from_setindex(ex)::ArrayReference
    # For `setindex!(A, v, inds...)`, `ex.args` is `[:setindex!, A, v, inds...]`:
    # the value being stored is the third entry; the second is the destination array.
    storedvar = (ex.args[3])::Symbol
    return add_store!(ls, storedvar, ref, elementbytes)
end
518543# add operation assigns X to var
519544function add_operation! (
520545 ls:: LoopSet , LHS:: Symbol , RHS:: Expr , elementbytes:: Int = 8
521546)
547+ # @show LHS, RHS
522548 if RHS. head === :ref
523549 add_load_ref! (ls, LHS, RHS, elementbytes)
524550 elseif RHS. head === :call
@@ -539,11 +565,17 @@ end
539565function add_operation! (
540566 ls:: LoopSet , LHS_sym:: Symbol , RHS:: Expr , LHS_ref:: ArrayReference , elementbytes:: Int = 8
541567)
568+ # @show LHS_sym, RHS
542569 if RHS. head === :ref # || (RHS.head === :call && first(RHS.args) === :getindex)
543570 add_load! (ls, LHS_sym, LHS_ref, elementbytes)
544571 elseif RHS. head === :call
545- if first (RHS. args) === :getindex
572+ f = first (RHS. args)
573+ if f === :getindex
546574 add_load! (ls, LHS_sym, LHS_ref, elementbytes)
575+ elseif f === :zero || f === :one
576+ c = gensym (:constant )
577+ pushpreamble! (ls, Expr (:(= ), c, RHS))
578+ add_constant! (ls, c, [keys (ls. loops)... ], LHS_sym, elementbytes)
547579 else
548580 add_compute! (ls, LHS_sym, RHS, elementbytes, LHS_ref)
549581 end
@@ -552,6 +584,7 @@ function add_operation!(
552584 end
553585end
554586function Base. push! (ls:: LoopSet , ex:: Expr , elementbytes:: Int = 8 )
587+ # @show ex
555588 if ex. head === :call
556589 finex = first (ex. args):: Symbol
557590 if finex === :setindex!
@@ -566,21 +599,25 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
566599 if RHS isa Expr
567600 add_operation! (ls, LHS, RHS, elementbytes)
568601 else
602+ # @show [keys(ls.loops)...]
569603 add_constant! (ls, RHS, [keys (ls. loops)... ], LHS, elementbytes)
570604 end
571605 elseif LHS isa Expr
572606 @assert LHS. head === :ref
573607 local lrhs:: Symbol
608+ # @show LHS, RHS
574609 if RHS isa Symbol
575610 lrhs = RHS
576611 elseif RHS isa Expr
577612 # need to check of LHS appears in RHS
578613 # assign RHS to lrhs
579614 ref = ArrayReference (LHS)
580615 id = findfirst (r -> r == ref, ls. refs_aliasing_syms)
581- lrhs = id === nothing ? gensym (:RHS ) : ls. syms_aliasing_refs[id]
582- # we pass ref, so it can compare references within RHS, and realize
583- # they equal lrhs
616+ lrhs = if id === nothing
617+ gensym (:RHS )
618+ else
619+ ls. syms_aliasing_refs[id]
620+ end
584621 add_operation! (ls, lrhs, RHS, ref, elementbytes)
585622 end
586623 add_store_ref! (ls, lrhs, LHS, elementbytes)
0 commit comments