Skip to content

Commit 6b5b98f

Browse files
authored
Support Julia 1.12 (#529)
* Add herk! and syrk! * Avoid widemul and Int128 * Julia 1.12 matmatmul! dispatch * Broadcast * Fix return_type * Avoid pointer test in Julia 1.12 for now
1 parent 47782e6 commit 6b5b98f

File tree

9 files changed

+128
-58
lines changed

9 files changed

+128
-58
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ steps:
2626
julia:
2727
- "1.10"
2828
- "1.11"
29+
- "1.12"
2930
- "nightly"
3031
adjustments:
3132
- with:

lib/level-zero/pointer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ Base.eltype(::Type{<:ZePtr{T}}) where {T} = T
4040
Base.convert(::Type{T}, x::ZePtr) where {T<:Integer} = T(UInt(x))
4141
## integer to pointer
4242
Base.convert(::Type{ZePtr{T}}, x::Union{Int,UInt}) where {T} = ZePtr{T}(x)
43-
Int(x::ZePtr) = Base.bitcast(Int, x)
44-
UInt(x::ZePtr) = Base.bitcast(UInt, x)
43+
Base.Int(x::ZePtr) = Base.bitcast(Int, x)
44+
Base.UInt(x::ZePtr) = Base.bitcast(UInt, x)
4545

4646
# between regular and oneAPI pointers
4747
Base.convert(::Type{<:Ptr}, p::ZePtr) =
@@ -71,8 +71,8 @@ Base.:(==)(x::ZePtr, y::ZePtr) = UInt(x) == UInt(y)
7171
Base.:(<)(x::ZePtr, y::ZePtr) = UInt(x) < UInt(y)
7272
Base.:(-)(x::ZePtr, y::ZePtr) = UInt(x) - UInt(y)
7373

74-
Base.:(+)(x::ZePtr, y::Integer) = oftype(x, Base.add_ptr(UInt(x), (y % UInt) % UInt))
75-
Base.:(-)(x::ZePtr, y::Integer) = oftype(x, Base.sub_ptr(UInt(x), (y % UInt) % UInt))
74+
Base.:(+)(x::ZePtr, y::Integer) = oftype(x, UInt(x) + (y % UInt) % UInt)
75+
Base.:(-)(x::ZePtr, y::Integer) = oftype(x, UInt(x) - (y % UInt) % UInt)
7676
Base.:(+)(x::Integer, y::ZePtr) = y + x
7777

7878

lib/mkl/interfaces.jl

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,42 @@
22

33
using LinearAlgebra: BlasComplex, BlasFloat, BlasReal, MulAddMul
44

5-
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasFloat
5+
# legacy methods with final MulAddMul argument
6+
LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
7+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
8+
LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
9+
LinearAlgebra.generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
10+
LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
11+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
12+
LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where {T <: Union{Float16, ComplexF16, BlasFloat}} =
13+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
14+
15+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSR{T}, B::oneVector{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
616
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
7-
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
17+
return sparse_gemv!(tA, alpha, A, B, beta, C)
818
end
919

10-
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, _add::MulAddMul) where T <: BlasReal
20+
function LinearAlgebra.generic_matvecmul!(C::oneVector{T}, tA::AbstractChar, A::oneSparseMatrixCSC{T}, B::oneVector{T}, alpha::Number, beta::Number) where {T <: BlasReal}
1121
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
12-
sparse_gemv!(tA, _add.alpha, A, B, _add.beta, C)
22+
return sparse_gemv!(tA, alpha, A, B, beta, C)
1323
end
1424

15-
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasFloat
25+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}, alpha::Number, beta::Number) where {T <: BlasFloat}
1626
tA = tA in ('S', 's', 'H', 'h') ? 'N' : tA
1727
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
18-
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
28+
return sparse_gemm!(tA, tB, alpha, A, B, beta, C)
1929
end
2030

21-
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, _add::MulAddMul) where T <: BlasReal
31+
function LinearAlgebra.generic_matmatmul!(C::oneMatrix{T}, tA, tB, A::oneSparseMatrixCSC{T}, B::oneMatrix{T}, alpha::Number, beta::Number) where {T <: BlasReal}
2232
tA = tA in ('S', 's', 'H', 'h') ? 'T' : flip_trans(tA)
2333
tB = tB in ('S', 's', 'H', 'h') ? 'N' : tB
24-
sparse_gemm!(tA, tB, _add.alpha, A, B, _add.beta, C)
34+
return sparse_gemm!(tA, tB, alpha, A, B, beta, C)
2535
end
2636

27-
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where T <: BlasFloat
28-
sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
37+
function LinearAlgebra.generic_trimatdiv!(C::oneVector{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneVector{T}) where {T <: BlasFloat}
38+
return sparse_trsv!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
2939
end
3040

31-
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where T <: BlasFloat
32-
sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
41+
function LinearAlgebra.generic_trimatdiv!(C::oneMatrix{T}, uploc, isunitc, tfun::Function, A::oneSparseMatrixCSR{T}, B::oneMatrix{T}) where {T <: BlasFloat}
42+
return sparse_trsm!(uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', 'N', isunitc, one(T), A, B, C)
3343
end

lib/mkl/linalg.jl

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ function LinearAlgebra.generic_matvecmul!(Y::oneVector, tA::AbstractChar, A::one
104104
end
105105
end
106106
end
107-
LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, MulAddMul(alpha, beta))
107+
return LinearAlgebra.generic_matmatmul!(Y, tA, 'N', A, B, alpha, beta)
108108
end
109109

110110
# triangular
@@ -120,46 +120,71 @@ LinearAlgebra.generic_trimatdiv!(C::oneStridedVector{T}, uploc, isunitc, tfun::F
120120
# BLAS 3
121121
#
122122

123-
LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, _add::MulAddMul=MulAddMul()) =
124-
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
125-
function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStridedVecOrMat, B::oneStridedVecOrMat, a::Number, b::Number)
123+
if VERSION >= v"1.12-"
124+
# Otherwise dispatches onto:
125+
# https://github.com/JuliaLang/LinearAlgebra.jl/blob/4e7c3f40316a956119ac419a97c4b8aad7a17e6c/src/matmul.jl#L490
126+
for blas_flag in (LinearAlgebra.BlasFlag.SyrkHerkGemm, LinearAlgebra.BlasFlag.SymmHemmGeneric)
127+
@eval LinearAlgebra.generic_matmatmul_wrapper!(
128+
C::oneStridedMatrix, tA::AbstractChar, tB::AbstractChar, A::oneStridedVecOrMat, B::oneStridedVecOrMat,
129+
alpha::Number, beta::Number, ::$blas_flag
130+
) =
131+
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, alpha, beta)
132+
end
133+
end
134+
135+
LinearAlgebra.generic_matmatmul!(
136+
C::oneStridedVecOrMat, tA, tB, A::oneStridedVecOrMat,
137+
B::oneStridedVecOrMat, _add::MulAddMul,
138+
) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
139+
function LinearAlgebra.generic_matmatmul!(
140+
C::oneStridedVecOrMat, tA, tB, A::oneStridedVecOrMat,
141+
B::oneStridedVecOrMat, alpha::Number, beta::Number,
142+
)
126143
T = eltype(C)
127-
alpha, beta = promote(a, b, zero(T))
128144
mA, nA = size(A, tA == 'N' ? 1 : 2), size(A, tA == 'N' ? 2 : 1)
129145
mB, nB = size(B, tB == 'N' ? 1 : 2), size(B, tB == 'N' ? 2 : 1)
130146

131-
if nA != mB
132-
throw(DimensionMismatch("A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"))
133-
end
134-
135-
if C === A || B === C
136-
throw(ArgumentError("output matrix must not be aliased with input matrix"))
137-
end
147+
nA != mB && throw(
148+
DimensionMismatch(
149+
"A has dimensions ($mA,$nA) but B has dimensions ($mB,$nB)"
150+
)
151+
)
152+
(C === A || B === C) && throw(
153+
ArgumentError(
154+
"output matrix must not be aliased with input matrix"
155+
)
156+
)
138157

139158
if mA == 0 || nA == 0 || nB == 0
140-
if size(C) != (mA, nB)
141-
throw(DimensionMismatch("C has dimensions $(size(C)), should have ($mA,$nB)"))
142-
end
159+
size(C) != (mA, nB) && throw(
160+
DimensionMismatch(
161+
"C has dimensions $(size(C)), should have ($mA,$nB)"
162+
)
163+
)
143164
return LinearAlgebra.rmul!(C, 0)
144165
end
145166

146-
if all(in(('N', 'T', 'C')), (tA, tB))
147-
if T <: Union{onemklFloat, onemklComplex, onemklHalf} && eltype(A) == eltype(B) == T
148-
return gemm!(tA, tB, alpha, A, B, beta, C)
149-
end
150-
end
167+
T = eltype(C)
168+
151169
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
152170
# TODO: should the gemm part above be included in this branch?
153-
if (tA == 'S' || tA == 's') && tB == 'N'
154-
return symm!('L', tA == 'S' ? 'U' : 'L', alpha, A, B, beta, C)
171+
α, β = T(alpha), T(beta)
172+
if (
173+
all(in(('N', 'T', 'C')), (tA, tB)) && T <: Union{onemklFloat, onemklComplex, onemklHalf} &&
174+
A isa oneStridedArray{T} && B isa oneStridedArray{T}
175+
)
176+
return gemm!(tA, tB, α, A, B, β, C)
177+
elseif (tA == 'S' || tA == 's') && tB == 'N'
178+
return symm!('L', tA == 'S' ? 'U' : 'L', α, A, B, β, C)
155179
elseif (tB == 'S' || tB == 's') && tA == 'N'
156-
return symm!('R', tB == 'S' ? 'U' : 'L', alpha, B, A, beta, C)
180+
return symm!('R', tB == 'S' ? 'U' : 'L', α, B, A, β, C)
157181
elseif (tA == 'H' || tA == 'h') && tB == 'N'
158-
return hemm!('L', tA == 'H' ? 'U' : 'L', alpha, A, B, beta, C)
182+
return hemm!('L', tA == 'H' ? 'U' : 'L', α, A, B, β, C)
159183
elseif (tB == 'H' || tB == 'h') && tA == 'N'
160-
return hemm!('R', tB == 'H' ? 'U' : 'L', alpha, B, A, beta, C)
184+
return hemm!('R', tB == 'H' ? 'U' : 'L', α, B, A, β, C)
161185
end
162186
end
187+
163188
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
164189
end
165190

@@ -172,3 +197,23 @@ LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::F
172197
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B))
173198
LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
174199
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A))
200+
201+
#
202+
# BLAS extensions
203+
#
204+
205+
# Extend LinearAlgebra.BLAS.herk! to dispatch to oneAPI implementation
206+
for (elty) in ([Float32, ComplexF32], [Float64, ComplexF64])
207+
@eval begin
208+
LinearAlgebra.BLAS.herk!(uplo::Char, trans::Char, alpha::$elty[1], A::oneStridedVecOrMat{$elty[2]}, beta::$elty[1], C::oneStridedMatrix{$elty[2]}) =
209+
herk!(uplo, trans, alpha, A, beta, C)
210+
end
211+
end
212+
213+
# Extend LinearAlgebra.BLAS.syrk! to dispatch to oneAPI implementation
214+
for (elty) in (Float32, Float64, ComplexF32, ComplexF64)
215+
@eval begin
216+
LinearAlgebra.BLAS.syrk!(uplo::Char, trans::Char, alpha::$elty, A::oneStridedVecOrMat{$elty}, beta::$elty, C::oneStridedMatrix{$elty}) =
217+
syrk!(uplo, trans, alpha, A, beta, C)
218+
end
219+
end

lib/mkl/wrappers_sparse.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
2929
return dA
3030
end
3131

32-
function SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
32+
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCSR{$elty, $intty})
3333
handle_ptr = Ref{matrix_handle_t}()
3434
At = SparseMatrixCSC(reverse(A.dims)..., Vector(A.rowPtr), Vector(A.colVal), Vector(A.nzVal))
3535
A_csc = SparseMatrixCSC(At |> transpose)
@@ -51,7 +51,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_csr_data , :Float32 , :Int3
5151
return dA
5252
end
5353

54-
function SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
54+
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCSC{$elty, $intty})
5555
handle_ptr = Ref{matrix_handle_t}()
5656
A_csc = SparseMatrixCSC(A.dims..., Vector(A.colPtr), Vector(A.rowVal), Vector(A.nzVal))
5757
return A_csc
@@ -84,7 +84,7 @@ for (fname, elty, intty) in ((:onemklSsparse_set_coo_data , :Float32 , :Int3
8484
return dA
8585
end
8686

87-
function SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
87+
function SparseArrays.SparseMatrixCSC(A::oneSparseMatrixCOO{$elty, $intty})
8888
handle_ptr = Ref{matrix_handle_t}()
8989
A = sparse(Vector(A.rowInd), Vector(A.colInd), Vector(A.nzVal), A.dims...)
9090
return A

src/broadcast.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
# broadcasting
2-
3-
using Base.Broadcast: BroadcastStyle, Broadcasted
1+
import Base.Broadcast: BroadcastStyle, Broadcasted
42

53
struct oneArrayStyle{N,B} <: AbstractGPUArrayStyle{N} end
64
oneArrayStyle{M,B}(::Val{N}) where {N,M,B} = oneArrayStyle{N,B}()
75

86
# identify the broadcast style of a (wrapped) oneArray
9-
BroadcastStyle(::Type{<:oneArray{T,N,B}}) where {T,N,B} = oneArrayStyle{N,B}()
10-
BroadcastStyle(W::Type{<:oneWrappedArray{T,N}}) where {T,N} =
7+
BroadcastStyle(::Type{<:oneArray{T, N, B}}) where {T, N, B} = oneArrayStyle{N, B}()
8+
BroadcastStyle(W::Type{<:oneWrappedArray{T, N}}) where {T, N} =
119
oneArrayStyle{N, buftype(Adapt.unwrap_type(W))}()
1210

1311
# when we are dealing with different buffer styles, we cannot know
1412
# which one is better, so use shared memory
15-
BroadcastStyle(::oneArrayStyle{N, B1},
16-
::oneArrayStyle{N, B2}) where {N,B1,B2} =
13+
BroadcastStyle(
14+
::oneArrayStyle{N, B1},
15+
::oneArrayStyle{N, B2},
16+
) where {N,B1,B2} =
1717
oneArrayStyle{N, oneL0.SharedBuffer}()
1818

1919
# allocation of output arrays

src/compiler/reflection.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ export @device_code_lowered, @device_code_typed, @device_code_warntype,
6565
#
6666

6767
"""
68-
Metal.return_type(f, tt) -> r::Type
68+
return_type(f, tt) -> r::Type
6969
7070
Return a type `r` such that `f(args...)::r` where `args::tt`.
7171
"""
@@ -75,5 +75,5 @@ function return_type(@nospecialize(func), @nospecialize(tt))
7575
job = CompilerJob(source, config)
7676
interp = GPUCompiler.get_interpreter(job)
7777
sig = Base.signature_type(func, tt)
78-
Core.Compiler.return_type(interp, sig)
78+
return Core.Compiler._return_type(interp, sig)
7979
end

src/device/quirks.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,13 @@ end
5757
@print_and_throw "Out-of-bounds access of scalar value"
5858
x
5959
end
60+
61+
# From Metal.jl to avoid widemul and Int128
62+
@static if VERSION >= v"1.12.0-DEV.1736" # Partially reverts JuliaLang/julia PR #56750
63+
let BitInteger64 = Union{Int64, UInt64}
64+
@device_override function Base.checkbounds(::Type{Bool}, v::StepRange{<:BitInteger64, <:BitInteger64}, i::BitInteger64)
65+
@inline
66+
return checkindex(Bool, eachindex(IndexLinear(), v), i)
67+
end
68+
end
69+
end

test/execution.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,12 @@ end
307307
@oneapi kernel(arr)
308308
@test Array(arr)[] == 1
309309

310-
function kernel(ptr)
310+
function kernel2(ptr)
311311
ptr[] = 2
312312
return
313313
end
314314

315-
@oneapi kernel(arr)
315+
@oneapi kernel2(arr)
316316
@test Array(arr)[] == 2
317317
end
318318

@@ -611,9 +611,13 @@ end
611611
return
612612
end
613613

614-
a = oneArray(Float32[0])
615-
@oneapi kernel(pointer(a))
616-
@test Array(a) == [42]
614+
if VERSION < v"1.12"
615+
a = oneArray(Float32[0])
616+
@oneapi kernel(pointer(a))
617+
@test Array(a) == [42]
618+
else
619+
@test_broken false
620+
end
617621
end
618622

619623
############################################################################################

0 commit comments

Comments
 (0)