Skip to content

Commit 8ed082c

Browse files
committed
Fix #15.
1 parent da4139d commit 8ed082c

File tree

4 files changed

+112
-44
lines changed

4 files changed

+112
-44
lines changed

src/graphs.jl

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,22 @@ isstaticloop(ls::LoopSet, s::Symbol) = ls.loops[s].hintexact
158158
looprangehint(ls::LoopSet, s::Symbol) = ls.loops[s].rangehint
159159
looprangesym(ls::LoopSet, s::Symbol) = ls.loops[s].rangesym
160160
# itersyms(ls::LoopSet) = keys(ls.loops)
161-
getop(ls::LoopSet, s::Symbol) = ls.opdict[s]
161+
function getop(ls::LoopSet, var::Symbol, elementbytes::Int = 8)
162+
get!(ls.opdict, var) do
163+
# might add constant
164+
op = add_constant!(ls, var, elementbytes)
165+
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
166+
op
167+
end
168+
end
169+
function getop(ls::LoopSet, var::Symbol, deps, elementbytes::Int = 8)
170+
get!(ls.opdict, var) do
171+
# might add constant
172+
op = add_constant!(ls, var, deps, gensym(:constant), elementbytes)
173+
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
174+
op
175+
end
176+
end
162177
getop(ls::LoopSet, i::Int) = ls.operations[i + 1]
163178

164179
@inline extract_val(::Val{N}) where {N} = N
@@ -284,7 +299,7 @@ function add_loop!(ls::LoopSet, q::Expr, elementbytes::Int = 8)
284299
if body.head === :block
285300
add_block!(ls, body, elementbytes)
286301
else
287-
Base.push!(ls, q, elementbytes)
302+
push!(ls, q, elementbytes)
288303
end
289304
end
290305
function add_loop!(ls::LoopSet, loop::Loop)
@@ -316,12 +331,18 @@ function add_load!(
316331
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
317332
)
318333
if ref.loaded[] == true
319-
op = getop(ls, var)
334+
op = getop(ls, var, elementbytes)
320335
@assert var === op.variable
321336
return op
322337
end
323-
push!(ls.syms_aliasing_refs, var)
324-
push!(ls.refs_aliasing_syms, ref)
338+
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
339+
if id === nothing
340+
push!(ls.syms_aliasing_refs, var)
341+
push!(ls.refs_aliasing_syms, ref)
342+
else
343+
opp = getop(ls, ls.syms_aliasing_refs[id], elementbytes)
344+
return isstore(opp) ? getop(ls, first(parents(opp))) : opp
345+
end
325346
ref.loaded[] = true
326347
# ls.sym_to_ref_aliases[ var ] = ref
327348
# ls.ref_to_sym_aliases[ ref ] = var
@@ -427,7 +448,7 @@ function maybe_cse_load!(ls::LoopSet, expr::Expr, elementbytes::Int = 8)
427448
if id === nothing
428449
add_load!( ls, gensym(:temporary), ref, elementbytes )
429450
else
430-
getop(ls, ls.syms_aliasing_refs[id])
451+
getop(ls, ls.syms_aliasing_refs[id], elementbytes)
431452
end
432453
# id = includesarray(ls, array)
433454
# if id > 0
@@ -440,12 +461,7 @@ function add_parent!(
440461
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet, var, elementbytes::Int = 8
441462
)
442463
parent = if var isa Symbol
443-
get!(ls.opdict, var) do
444-
# might add constant
445-
op = add_constant!(ls, var, elementbytes)
446-
pushpreamble!(ls, Expr(:(=), mangledvar(op), var))
447-
op
448-
end
464+
getop(ls, var, elementbytes)
449465
elseif var isa Expr #CSE candidate
450466
maybe_cse_load!(ls, var, elementbytes)
451467
else # assumed constant
@@ -465,7 +481,7 @@ function add_reduction_update_parent!(
465481
parents::Vector{Operation}, deps::Vector{Symbol}, reduceddeps::Vector{Symbol}, ls::LoopSet,
466482
var::Symbol, instr::Symbol, elementbytes::Int = 8
467483
)
468-
parent = getop(ls, var)
484+
parent = getop(ls, var, elementbytes)
469485
setdiffv!(reduceddeps, deps, loopdependencies(parent))
470486
pushparent!(parents, deps, reduceddeps, parent) # deps and reduced deps will not be disjoint
471487
op = Operation(length(operations(ls)), var, elementbytes, instr, compute, deps, reduceddeps, parents)
@@ -502,23 +518,33 @@ end
502518
function add_store!(
503519
ls::LoopSet, var::Symbol, ref::ArrayReference, elementbytes::Int = 8
504520
)
505-
parent = getop(ls, var)
506-
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, loopdependencies(ref), reduceddependencies(parent), [parent], ref )
521+
# @show loopdependencies(ref)
522+
# @show ls.operations
523+
ldref = loopdependencies(ref)
524+
parent = getop(ls, var, ldref, elementbytes)
525+
pvar = parent.variable
526+
if pvar ∉ ls.syms_aliasing_refs
527+
push!(ls.syms_aliasing_refs, pvar)
528+
push!(ls.refs_aliasing_syms, ref)
529+
end
530+
op = Operation( length(operations(ls)), ref.array, elementbytes, :setindex!, memstore, ldref, reduceddependencies(parent), [parent], ref )
531+
# @show loopdependencies(op) op
507532
add_vptr!(ls, ref.array, identifier(op), ref.ptr)
508533
pushop!(ls, op, ref.array)
509534
end
510535
function add_store_ref!(ls::LoopSet, var::Symbol, ex::Expr, elementbytes::Int = 8)
511-
ref = ref_from_ref(ex)
536+
ref = ref_from_ref(ex)::ArrayReference
512537
add_store!(ls, var, ref, elementbytes)
513538
end
514539
function add_store_setindex!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
515-
ref = ref_from_setindex(ex)
516-
add_store!(ls, var, ref, elementbytes)
540+
ref = ref_from_setindex(ex)::ArrayReference
541+
add_store!(ls, (ex.args[2])::Symbol, ref, elementbytes)
517542
end
518543
# add operation assigns X to var
519544
function add_operation!(
520545
ls::LoopSet, LHS::Symbol, RHS::Expr, elementbytes::Int = 8
521546
)
547+
# @show LHS, RHS
522548
if RHS.head === :ref
523549
add_load_ref!(ls, LHS, RHS, elementbytes)
524550
elseif RHS.head === :call
@@ -539,11 +565,17 @@ end
539565
function add_operation!(
540566
ls::LoopSet, LHS_sym::Symbol, RHS::Expr, LHS_ref::ArrayReference, elementbytes::Int = 8
541567
)
568+
# @show LHS_sym, RHS
542569
if RHS.head === :ref# || (RHS.head === :call && first(RHS.args) === :getindex)
543570
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
544571
elseif RHS.head === :call
545-
if first(RHS.args) === :getindex
572+
f = first(RHS.args)
573+
if f === :getindex
546574
add_load!(ls, LHS_sym, LHS_ref, elementbytes)
575+
elseif f === :zero || f === :one
576+
c = gensym(:constant)
577+
pushpreamble!(ls, Expr(:(=), c, RHS))
578+
add_constant!(ls, c, [keys(ls.loops)...], LHS_sym, elementbytes)
547579
else
548580
add_compute!(ls, LHS_sym, RHS, elementbytes, LHS_ref)
549581
end
@@ -552,6 +584,7 @@ function add_operation!(
552584
end
553585
end
554586
function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
587+
# @show ex
555588
if ex.head === :call
556589
finex = first(ex.args)::Symbol
557590
if finex === :setindex!
@@ -566,21 +599,25 @@ function Base.push!(ls::LoopSet, ex::Expr, elementbytes::Int = 8)
566599
if RHS isa Expr
567600
add_operation!(ls, LHS, RHS, elementbytes)
568601
else
602+
# @show [keys(ls.loops)...]
569603
add_constant!(ls, RHS, [keys(ls.loops)...], LHS, elementbytes)
570604
end
571605
elseif LHS isa Expr
572606
@assert LHS.head === :ref
573607
local lrhs::Symbol
608+
# @show LHS, RHS
574609
if RHS isa Symbol
575610
lrhs = RHS
576611
elseif RHS isa Expr
577612
# need to check of LHS appears in RHS
578613
# assign RHS to lrhs
579614
ref = ArrayReference(LHS)
580615
id = findfirst(r -> r == ref, ls.refs_aliasing_syms)
581-
lrhs = id === nothing ? gensym(:RHS) : ls.syms_aliasing_refs[id]
582-
# we pass ref, so it can compare references within RHS, and realize
583-
# they equal lrhs
616+
lrhs = if id === nothing
617+
gensym(:RHS)
618+
else
619+
ls.syms_aliasing_refs[id]
620+
end
584621
add_operation!(ls, lrhs, RHS, ref, elementbytes)
585622
end
586623
add_store_ref!(ls, lrhs, LHS, elementbytes)

src/lowering.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ function reduce_unroll!(q, op, U, unrolled)
225225
return U, isunrolled
226226
end
227227
unrolled ∈ reduceddependencies(op) || return U
228-
var = pvariable_name(op, suffix)
228+
var = mangledvar(op)
229229
instr = first(parents(op)).instruction
230230
reduce_expr!(q, var, instr, U) # assigns reduction to storevar
231231
1, isunrolled

src/operations.jl

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,13 @@ function Base.hash(x::ArrayReference, h::UInt)
2323
end
2424
hash(x.array, h)
2525
end
26-
loopdependencies(ref::ArrayReference) = filter(i -> i isa Symbol, ref.ref)
26+
function loopdependencies(ref::ArrayReference)
27+
ld = Symbol[]
28+
for r ∈ ref.ref
29+
r isa Symbol && push!(ld, r)
30+
end
31+
ld
32+
end
2733
function Base.isequal(x::ArrayReference, y::ArrayReference)
2834
x.array === y.array || return false
2935
nrefs = length(x.ref)
@@ -74,30 +80,18 @@ end
7480

7581
# TODO: can some computations be cached in the operations?
7682
"""
77-
if ooperation_type == memstore || operation_type == memstore# || operation_type == compute_new || operation_type == compute_update
78-
symbolic metadata contains info on direct dependencies / placement within loop.
79-
80-
if isload(op) -> Symbol(:vptr_, first(op.reduced_deps))
81-
if istore(op) -> Symbol(:vptr_, op.variable)
82-
is how we access the memory.
83-
84-
is the stride for loop index
85-
symbolic_metadata[i]
8683
"""
8784
struct Operation
8885
identifier::Int
8986
variable::Symbol
9087
elementbytes::Int
9188
instruction::Instruction
9289
node_type::OperationType
93-
dependencies::Vector{Symbol}#::Vector{Symbol}
90+
dependencies::Vector{Symbol}
9491
reduced_deps::Vector{Symbol}
9592
parents::Vector{Operation}
9693
ref::ArrayReference
9794
mangledvariable::Symbol
98-
# children::Vector{Operation}
99-
# numerical_metadata::Vector{Int} # stride of -1 indicates dynamic
100-
# symbolic_metadata::Vector{Symbol}
10195
function Operation(
10296
identifier::Int,
10397
variable,

test/runtests.jl

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,33 @@ using LinearAlgebra
3333
@test logsumexp!(r, x) ≈ 102.35216846104409
3434

3535
@testset "GEMM" begin
36-
AmulBq = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
36+
using LoopVectorization, Test
37+
U, T = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
38+
AmulBq1 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
39+
C[m,n] = zeroB
40+
for k ∈ 1:size(A,2)
41+
C[m,n] += A[m,k] * B[k,n]
42+
end
43+
end)
44+
lsAmulB1 = LoopVectorization.LoopSet(AmulBq1);
45+
@test LoopVectorization.choose_order(lsAmulB1) == (Symbol[:n,:m,:k], :m, U, T)
46+
AmulBq2 = :(for m ∈ 1:M, n ∈ 1:N
47+
C[m,n] = zero(eltype(B))
48+
for k ∈ 1:K
49+
C[m,n] += A[m,k] * B[k,n]
50+
end
51+
end)
52+
lsAmulB2 = LoopVectorization.LoopSet(AmulBq2);
53+
@test LoopVectorization.choose_order(lsAmulB2) == (Symbol[:n,:m,:k], :m, U, T)
54+
AmulBq3 = :(for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
3755
ΔCₘₙ = zero(eltype(C))
3856
for k ∈ 1:size(A,2)
3957
ΔCₘₙ += A[m,k] * B[k,n]
4058
end
4159
C[m,n] += ΔCₘₙ
4260
end)
43-
44-
lsAmulB = LoopVectorization.LoopSet(AmulBq);
45-
U, T = LoopVectorization.VectorizationBase.REGISTER_COUNT == 16 ? (3, 4) : (4, 4)
46-
@test LoopVectorization.choose_order(lsAmulB) == (Symbol[:n,:m,:k], :m, U, T)
61+
lsAmulB3 = LoopVectorization.LoopSet(AmulBq3);
62+
@test LoopVectorization.choose_order(lsAmulB3) == (Symbol[:n,:m,:k], :m, U, T)
4763

4864
function AmulB!(C, A, B)
4965
C .= 0
@@ -53,7 +69,7 @@ using LinearAlgebra
5369
end
5470
end
5571
end
56-
function AmulBavx!(C, A, B)
72+
function AmulBavx1!(C, A, B)
5773
@avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
5874
Cₘₙ = zero(eltype(C))
5975
for k ∈ 1:size(A,2)
@@ -62,6 +78,23 @@ using LinearAlgebra
6278
C[m,n] = Cₘₙ
6379
end
6480
end
81+
function AmulBavx2!(C, A, B)
82+
z = zero(eltype(C))
83+
@avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
84+
C[m,n] = z
85+
for k ∈ 1:size(A,2)
86+
C[m,n] += A[m,k] * B[k,n]
87+
end
88+
end
89+
end
90+
function AmulBavx3!(C, A, B)
91+
@avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
92+
C[m,n] = zero(eltype(C))
93+
for k ∈ 1:size(A,2)
94+
C[m,n] += A[m,k] * B[k,n]
95+
end
96+
end
97+
end
6598
function AmuladdBavx!(C, A, B, factor = 1)
6699
@avx for m ∈ 1:size(A,1), n ∈ 1:size(B,2)
67100
ΔCₘₙ = zero(eltype(C))
@@ -171,8 +204,12 @@ using LinearAlgebra
171204
C = Matrix{TC}(undef, M, N);
172205
A = rand(R, M, K); B = rand(R, K, N);
173206
C2 = similar(C);
174-
AmulBavx!(C, A, B)
175207
AmulB!(C2, A, B)
208+
AmulBavx1!(C, A, B)
209+
@test C ≈ C2
210+
fill!(C, 999.99); AmulBavx2!(C, A, B)
211+
@test C ≈ C2
212+
fill!(C, 999.99); AmulBavx3!(C, A, B)
176213
@test C ≈ C2
177214
fill!(C, 0.0); AmuladdBavx!(C, A, B)
178215
@test C ≈ C2

0 commit comments

Comments
 (0)