9292
9393
"""
    alloc_matmul_product(A::AbstractArray, B::AbstractMatrix)

Allocate uninitialized storage for the product of `A` and `B`.

Return `(C, (M, K, N))`, where `M, K = size(A)` and `K, N = size(B)`. The element type
of `C` is `promote_type(TA, TB)`. `C` is a transposed length-`N` `Vector` when
`M === StaticInt(1)`, and an `M`-by-`N` `Matrix` otherwise. The contents of `C` are
`undef` until a kernel writes them.
"""
@inline function alloc_matmul_product(
    A::AbstractArray{TA}, B::AbstractMatrix{TB}
) where {TA,TB}
    # TODO: if `M` and `N` are statically sized, shouldn't return a `Matrix`.
    M, KA = size(A)
    KB, N = size(B)
    @assert KA == KB " Size mismatch."
    T = promote_type(TA, TB)
    dims = (M, KA, N)
    # A statically known single row gets a transposed vector instead of a 1×N matrix.
    if M === StaticInt(1)
        return transpose(Vector{T}(undef, N)), dims
    end
    return Matrix{T}(undef, M, N), dims
end
101105@inline function alloc_matmul_product (A:: AbstractArray{TA} , B:: AbstractVector{TB} ) where {TA,TB}
102106 # TODO : if `M` and `N` are statically sized, shouldn't return a `Matrix`.
@@ -105,10 +109,11 @@ end
105109 @assert KA == KB " Size mismatch."
106110 Vector {promote_type(TA,TB)} (undef, M), (M, KA, One ())
107111end
112+
"""
    matmul_serial(A::AbstractMatrix, B::AbstractVecOrMat)

Allocate a destination with `alloc_matmul_product` and fill it through
`matmul_serial!`, using `One()` and `Zero()` as the `α` and `β` arguments.
Return the newly allocated destination.
"""
@inline function matmul_serial(A::AbstractMatrix, B::AbstractVecOrMat)
    C, dims = alloc_matmul_product(A, B)
    # Pass the already-known sizes and the destination's contiguous axis so the
    # layout-specific method is selected by dispatch.
    caxis = ArrayInterface.contiguous_axis(C)
    matmul_serial!(C, A, B, One(), Zero(), dims, caxis)
    return C
end
113118
114119
@@ -132,12 +137,16 @@ end
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β
)
    # No size tuple is known here, so pass `nothing` and dispatch on the
    # contiguous axis of the destination.
    caxis = ArrayInterface.contiguous_axis(C)
    return matmul_serial!(C, A, B, α, β, nothing, caxis)
end
# Destination whose contiguous axis is 2, sizes unknown: forward the transposed
# operands in swapped order. (A*B)ᵀ == Bᵀ*Aᵀ, so this is consistent with
# `_matmul_serial!` being a product kernel — NOTE(review): confirm against its definition.
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, ::Nothing, ::StaticInt{2}
)
    Ct = transpose(C)
    Bt = transpose(B)
    At = transpose(A)
    _matmul_serial!(Ct, Bt, At, α, β, nothing)
    return C
end
# Same as above, but with known sizes: the `(M, K, N)` tuple is forwarded as `(N, K, M)`
# to match the swapped, transposed operands.
@inline function matmul_serial!(
    C::AbstractVecOrMat,
    A::AbstractMatrix,
    B::AbstractVecOrMat,
    α,
    β,
    MKN::Tuple{Vararg{Integer,3}},
    ::StaticInt{2},
)
    M, K, N = MKN
    _matmul_serial!(transpose(C), transpose(B), transpose(A), α, β, (N, K, M))
    return C
end
# Fallback for any other contiguous axis: operands and the size information
# (`MKN`, possibly `nothing`) are forwarded unchanged.
@inline function matmul_serial!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, MKN, ::StaticInt
)
    _matmul_serial!(C, A, B, α, β, MKN)
    return C
end
143152
212221Multiply matrices `A` and `B`.
213222"""
@inline function matmul(A::AbstractMatrix, B::AbstractVecOrMat)
    # Allocate the destination; the sizes come back alongside it, so they are
    # passed on rather than recomputed.
    C, dims = alloc_matmul_product(A, B)
    # `nothing` fills the `nthread` slot; the contiguous axis of `C` selects the
    # layout-specific `matmul!` method.
    caxis = ArrayInterface.contiguous_axis(C)
    matmul!(C, A, B, One(), Zero(), nothing, dims, caxis)
    return C
end
219228
220229"""
@@ -235,13 +244,17 @@ end
@inline function matmul!(
    C::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, α, β, nthread
)
    # No size tuple is known here, so pass `nothing` and dispatch on the
    # contiguous axis of the destination.
    caxis = ArrayInterface.contiguous_axis(C)
    return matmul!(C, A, B, α, β, nthread, nothing, caxis)
end
# Destination whose contiguous axis is 2, sizes unknown: forward the transposed
# operands in swapped order. (A*B)ᵀ == Bᵀ*Aᵀ, so this is consistent with
# `_matmul!` being a product kernel — NOTE(review): confirm against its definition.
@inline function matmul!(
    C::AbstractVecOrMat,
    A::AbstractMatrix,
    B::AbstractVecOrMat,
    α,
    β,
    nthread,
    ::Nothing,
    ::StaticInt{2},
)
    Ct = transpose(C)
    Bt = transpose(B)
    At = transpose(A)
    _matmul!(Ct, Bt, At, α, β, nthread, nothing)
    return C
end
# Same as above, but with known sizes: the `(M, K, N)` tuple is forwarded as `(N, K, M)`
# to match the swapped, transposed operands.
@inline function matmul!(
    C::AbstractVecOrMat,
    A::AbstractMatrix,
    B::AbstractVecOrMat,
    α,
    β,
    nthread,
    MKN::Tuple{Vararg{Integer,3}},
    ::StaticInt{2},
)
    M, K, N = MKN
    _matmul!(transpose(C), transpose(B), transpose(A), α, β, nthread, (N, K, M))
    return C
end
# Fallback for any other contiguous axis: operands, `nthread` and the size
# information (`MKN`, possibly `nothing`) are forwarded unchanged.
@inline function matmul!(
    C::AbstractVecOrMat,
    A::AbstractMatrix,
    B::AbstractVecOrMat,
    α,
    β,
    nthread,
    MKN,
    ::StaticInt,
)
    _matmul!(C, A, B, α, β, nthread, MKN)
    return C
end
246259
247260@inline function dontpack (pA:: AbstractStridedPointer{Ta} , M, K, :: StaticInt{mc} , :: StaticInt{kc} , :: Type{Tc} , nspawn) where {mc, kc, Tc, Ta}
0 commit comments