|
| 1 | +struct Product{A,B} |
| 2 | + a::A |
| 3 | + b::B |
| 4 | +end |
| 5 | + |
| 6 | +@inline ∗(a::A, b::B) where {A,B} = Product{A,B}(a, b) |
| 7 | +@inline Base.Broadcast.Broadcasted(::typeof(∗), a::A, b::B) where {A, B} = Product{A,B}(a, b) |
| 8 | +# TODO: Need to make this handle A or B being (1 or 2)-D broadcast objects. |
| 9 | +function add_broadcast!( |
| 10 | + ls::LoopSet, mC::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, |
| 11 | + ::Type{Product{A,B}}, elementbytes::Int = 8 |
| 12 | +) where {T,A,B} |
| 13 | + K = gensym(:K) |
| 14 | + mA = gensym(:Aₘₖ) |
| 15 | + mB = gensym(:Bₖₙ) |
| 16 | + pushpreamble!(ls, Expr(:(=), mA, Expr(:(.), bcname, QuoteNode(:a)))) |
| 17 | + pushpreamble!(ls, Expr(:(=), mB, Expr(:(.), bcname, QuoteNode(:b)))) |
| 18 | + pushpreamble!(ls, Expr(:(=), K, Expr(:call, :size, mB, 1))) |
| 19 | + |
| 20 | + k = gensym(:k) |
| 21 | + ls.loops[k] = Loop(k, K) |
| 22 | + m = loopsyms[1]; n = loopsyms[2]; |
| 23 | + # load A |
| 24 | + # loadA = add_load!(ls, gensym(:A), productref(A, mA, m, k), elementbytes) |
| 25 | + loadA = add_broadcast!(ls, gensym(:A), mA, [m,k], A, elementbytes) |
| 26 | + # load B |
| 27 | + loadB = add_broadcast!(ls, gensym(:B), mB, [k,n], B, elementbytes) |
| 28 | + # set Cₘₙ = 0 |
| 29 | + setC = add_constant!(ls, 0.0, Symbol[m, k], mC, elementbytes) |
| 30 | + # compute Cₘₙ += Aₘₖ * Bₖₙ |
| 31 | + reductop = Operation( |
| 32 | + ls, mC, elementbytes, :vmuladd, compute, Symbol[m, k, n], Symbol[k], Operation[loadA, loadB, setC] |
| 33 | + ) |
| 34 | + pushop!(ls, reductop, mC) |
| 35 | +end |
| 36 | + |
| 37 | +struct LowDimArray{D,T,N,A<:DenseArray{T,N}} <: DenseArray{T,N} |
| 38 | + data::A |
| 39 | +end |
| 40 | +@inline Base.pointer(A::LowDimArray) = pointer(A) |
| 41 | +function LowDimArray{D}(data::A) where {D,T,N,A <: AbstractArray{T,N}} |
| 42 | + LowDimArray{D,T,N,A}(data) |
| 43 | +end |
| 44 | +function add_broadcast!( |
| 45 | + ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, |
| 46 | + ::Type{<:LowDimArray{D,T,N}}, elementbytes::Int = 8 |
| 47 | +) where {D,T,N} |
| 48 | + fulldims = Union{Symbol,Int}[loopsyms[n] for n ∈ 1:N if D[n]] |
| 49 | + ref = ArrayReference(bcname, fulldims, Ref{Bool}(false)) |
| 50 | + add_load!(ls, destname, ref, elementbytes)::Operation |
| 51 | +end |
| 52 | +function add_broadcast!( |
| 53 | + ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, |
| 54 | + ::Type{Adjoint{T,A}}, elementbytes::Int = 8 |
| 55 | +) where {T, N, A <: AbstractArray{T,N}} |
| 56 | + ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[N + 1 - n] for n ∈ 1:N], Ref{Bool}(false)) |
| 57 | + add_load!( ls, destname, ref, elementbytes )::Operation |
| 58 | +end |
| 59 | +function add_broadcast!( |
| 60 | + ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, |
| 61 | + ::Type{Adjoint{T,V}}, elementbytes::Int = 8 |
| 62 | +) where {T, V <: AbstractVector{T}} |
| 63 | + ref = ArrayReference(bcname, Union{Symbol,Int}[loopsyms[2]], Ref{Bool}(false)) |
| 64 | + add_load!( ls, destname, ref, elementbytes ) |
| 65 | +end |
| 66 | +function add_broadcast!( |
| 67 | + ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, |
| 68 | + ::Type{<:AbstractArray{T,N}}, elementbytes::Int = 8 |
| 69 | +) where {T,N} |
| 70 | + add_load!(ls, destname, ArrayReference(bcname, @view(loopsyms[1:N]), Ref{Bool}(false)), elementbytes) |
| 71 | +end |
| 72 | +function add_broadcast!( |
| 73 | + ls::LoopSet, destname::Symbol, bcname::Symbol, loopsyms::Vector{Symbol}, |
| 74 | + ::Type{Broadcasted{DefaultArrayStyle{N},Nothing,F,A}}, |
| 75 | + elementbytes::Int = 8 |
| 76 | +) where {N,F,A} |
| 77 | + instr = get(FUNCTIONSYMBOLS, F) do |
| 78 | + f = gensym(:f) |
| 79 | + pushpreamble!(ls, Expr(:(=), f, Expr(:(.), bcname, QuoteNode(:f)))) |
| 80 | + f |
| 81 | + end |
| 82 | + args = A.parameters |
| 83 | + Nargs = length(args) |
| 84 | + bcargs = Expr(:(.), bcname, QuoteNode(:args)) |
| 85 | + # this is the var name in the loop |
| 86 | + parents = Operation[] |
| 87 | + deps = Symbol[] |
| 88 | + reduceddeps = Symbol[] |
| 89 | + for (i,arg) ∈ enumerate(args) |
| 90 | + argname = gensym(:arg) |
| 91 | + pushpreamble!(ls, Expr(:(=), argname, Expr(:ref, bcargs, i))) |
| 92 | + # dynamic dispatch |
| 93 | + parent = add_broadcast!(ls, gensym(:temp), argname, loopsyms, arg)::Operation |
| 94 | + pushparent!(parents, deps, reduceddeps, parent) |
| 95 | + end |
| 96 | + op = Operation( |
| 97 | + length(operations(ls)), destname, elementbytes, instr, compute, deps, reduceddeps, parents |
| 98 | + ) |
| 99 | + pushop!(ls, op, destname) |
| 100 | +end |
| 101 | + |
| 102 | +# size of dest determines loops |
| 103 | +# @generated |
| 104 | +function vmaterialize!( |
| 105 | + dest::AbstractArray{T,N}, bc::BC |
| 106 | +# ) where {T, N, BC <: Broadcasted} |
| 107 | +) where {N, T, BC <: Broadcasted} |
| 108 | + # we have an N dimensional loop. |
| 109 | + # need to construct the LoopSet |
| 110 | + loopsyms = [gensym(:n) for n ∈ 1:N] |
| 111 | + ls = LoopSet() |
| 112 | + sizes = Expr(:tuple,) |
| 113 | + for (n,itersym) ∈ enumerate(loopsyms) |
| 114 | + Nsym = gensym(:N) |
| 115 | + ls.loops[itersym] = Loop(itersym, Nsym) |
| 116 | + push!(sizes.args, Nsym) |
| 117 | + end |
| 118 | + pushpreamble!(ls, Expr(:(=), sizes, Expr(:call, :size, :dest))) |
| 119 | + add_broadcast!(ls, :dest, :bc, loopsyms, BC) |
| 120 | + add_store!(ls, :dest, ArrayReference(:dest, loopsyms, Ref{Bool}(false))) |
| 121 | + resize!(ls.loop_order, num_loops(ls)) # num_loops may be greater than N, eg Product |
| 122 | + # lower(ls) |
| 123 | + ls |
| 124 | +end |
| 125 | + |
| 126 | +function vmaterialize(bc::Broadcasted) |
| 127 | + ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args) |
| 128 | + vmaterialize!(similar(bc, ElType), bc) |
| 129 | +end |
0 commit comments