Skip to content

Commit 189d49d

Browse files
authored
Make dot consistently return zero on empty arrays (#1494)
1 parent 1c0b673 commit 189d49d

File tree

4 files changed

+20
-14
lines changed

4 files changed

+20
-14
lines changed

src/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -993,7 +993,8 @@ function dot(x::AbstractArray, y::AbstractArray)
993993
throw(DimensionMismatch(lazy"first array has length $(lx) which does not match the length of the second, $(length(y))."))
994994
end
995995
if lx == 0
996-
return dot(zero(eltype(x)), zero(eltype(y)))
996+
# make sure the returned result equals exactly the zero element
997+
return zero(dot(zero(eltype(x)), zero(eltype(y))))
997998
end
998999
s = zero(dot(first(x), first(y)))
9991000
for (Ix, Iy) in zip(eachindex(x), eachindex(y))

src/symmetric.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,7 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni
573573
if n != size(B, 2)
574574
throw(DimensionMismatch(lazy"A has dimensions $(size(A)) but B has dimensions $(size(B))"))
575575
end
576-
576+
iszero(n) && return $real(zero(dot(zero(eltype(A)), zero(eltype(B)))))
577577
dotprod = $real(zero(dot(first(A), first(B))))
578578
@inbounds if A.uplo == 'U' && B.uplo == 'U'
579579
for j in 1:n

test/matmul.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -696,23 +696,24 @@ end
696696
@test dot(Z, Z) == convert(elty, 34.0)
697697
end
698698

699-
dot1(x, y) = invoke(dot, Tuple{Any,Any}, x, y)
700-
dot2(x, y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x, y)
701699
@testset "generic dot" begin
700+
dot1(x, y) = invoke(dot, Tuple{Any,Any}, x, y)
701+
dot2(x, y) = invoke(dot, Tuple{AbstractArray,AbstractArray}, x, y)
702702
AA = [1+2im 3+4im; 5+6im 7+8im]
703703
BB = [2+7im 4+1im; 3+8im 6+5im]
704704
for A in (copy(AA), view(AA, 1:2, 1:2)), B in (copy(BB), view(BB, 1:2, 1:2))
705705
@test dot(A, B) == dot(vec(A), vec(B)) == dot1(A, B) == dot2(A, B) == dot(float.(A), float.(B))
706-
@test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[])
707-
@test_throws MethodError dot(Any[], Any[])
708-
@test_throws MethodError dot1(Any[], Any[])
709-
@test_throws MethodError dot2(Any[], Any[])
710-
for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2)
711-
if n1 != n2
712-
@test_throws DimensionMismatch d(1:n1, 1:n2)
713-
else
714-
@test d(1:n1, 1:n2) norm(1:n1)^2
715-
end
706+
end
707+
@test dot(Int[], Int[]) == 0 == dot1(Int[], Int[]) == dot2(Int[], Int[])
708+
@test dot(ComplexF64[], Float64[]) === dot(ComplexF64[;;], Float64[;;]) === zero(ComplexF64)
709+
@test_throws MethodError dot(Any[], Any[])
710+
@test_throws MethodError dot1(Any[], Any[])
711+
@test_throws MethodError dot2(Any[], Any[])
712+
for n1 = 0:2, n2 = 0:2, d in (dot, dot1, dot2)
713+
if n1 != n2
714+
@test_throws DimensionMismatch d(1:n1, 1:n2)
715+
else
716+
@test d(1:n1, 1:n2) norm(1:n1)^2
716717
end
717718
end
718719
end

test/symmetric.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,10 @@ end
470470
@test dot(symblockmu, symblockml) dot(msymblockmu, msymblockml)
471471
@test dot(symblockml, symblockmu) dot(msymblockml, msymblockmu)
472472
@test dot(symblockml, symblockml) dot(msymblockml, msymblockml)
473+
474+
# empty matrices
475+
@test dot(mtype(ComplexF64[;;], :U), mtype(Float64[;;], :U)) === zero(mtype == Hermitian ? Float64 : ComplexF64)
476+
@test dot(mtype(ComplexF64[;;], :L), mtype(Float64[;;], :L)) === zero(mtype == Hermitian ? Float64 : ComplexF64)
473477
end
474478
end
475479

0 commit comments

Comments
 (0)