Skip to content

Commit 11ed992

Browse files
authored
SIMD support for math intrinsics (#379)
1 parent 8a1785f commit 11ed992

File tree

7 files changed

+201
-5
lines changed

7 files changed

+201
-5
lines changed

lib/intrinsics/Project.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,24 @@
11
name = "SPIRVIntrinsics"
22
uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
33
authors = ["Tim Besard <[email protected]>"]
4-
version = "0.5.3"
4+
version = "0.5.4"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
88
GPUToolbox = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
99
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1010
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1111

12+
[weakdeps]
13+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
14+
15+
[extensions]
16+
SPIRVIntrinsicsSIMDExt = "SIMD"
17+
1218
[compat]
1319
ExprTools = "0.1"
1420
GPUToolbox = "0.2, 0.3, 1"
1521
LLVM = "9.1"
22+
SIMD = "3.6"
1623
SpecialFunctions = "1.3, 2"
1724
julia = "1.10"
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
module SPIRVIntrinsicsSIMDExt
2+
3+
using SPIRVIntrinsics
4+
using SPIRVIntrinsics: @device_override, @device_function, @builtin_ccall, @typed_ccall
5+
using SIMD
6+
import SpecialFunctions
7+
8+
const known_intrinsics = String[]
9+
10+
# Generate vectorized math intrinsics
11+
for N in [2, 3, 4, 8, 16], T in [Float16, Float32, Float64]
12+
VT = :(Vec{$N,$T})
13+
LVT = :(SIMD.LVec{$N,$T})
14+
15+
@eval begin
16+
# Unary operations
17+
@device_override @inline Base.acos(x::$VT) = $VT(@builtin_ccall("acos", $LVT, ($LVT,), x.data))
18+
@device_override @inline Base.acosh(x::$VT) = $VT(@builtin_ccall("acosh", $LVT, ($LVT,), x.data))
19+
@device_function @inline SPIRVIntrinsics.acospi(x::$VT) = $VT(@builtin_ccall("acospi", $LVT, ($LVT,), x.data))
20+
21+
@device_override @inline Base.asin(x::$VT) = $VT(@builtin_ccall("asin", $LVT, ($LVT,), x.data))
22+
@device_override @inline Base.asinh(x::$VT) = $VT(@builtin_ccall("asinh", $LVT, ($LVT,), x.data))
23+
@device_function @inline SPIRVIntrinsics.asinpi(x::$VT) = $VT(@builtin_ccall("asinpi", $LVT, ($LVT,), x.data))
24+
25+
@device_override @inline Base.atan(x::$VT) = $VT(@builtin_ccall("atan", $LVT, ($LVT,), x.data))
26+
@device_override @inline Base.atanh(x::$VT) = $VT(@builtin_ccall("atanh", $LVT, ($LVT,), x.data))
27+
@device_function @inline SPIRVIntrinsics.atanpi(x::$VT) = $VT(@builtin_ccall("atanpi", $LVT, ($LVT,), x.data))
28+
29+
@device_override @inline Base.cbrt(x::$VT) = $VT(@builtin_ccall("cbrt", $LVT, ($LVT,), x.data))
30+
@device_override @inline Base.ceil(x::$VT) = $VT(@builtin_ccall("ceil", $LVT, ($LVT,), x.data))
31+
32+
@device_override @inline Base.cos(x::$VT) = $VT(@builtin_ccall("cos", $LVT, ($LVT,), x.data))
33+
@device_override @inline Base.cosh(x::$VT) = $VT(@builtin_ccall("cosh", $LVT, ($LVT,), x.data))
34+
@device_override @inline Base.cospi(x::$VT) = $VT(@builtin_ccall("cospi", $LVT, ($LVT,), x.data))
35+
36+
@device_override @inline SpecialFunctions.erfc(x::$VT) = $VT(@builtin_ccall("erfc", $LVT, ($LVT,), x.data))
37+
@device_override @inline SpecialFunctions.erf(x::$VT) = $VT(@builtin_ccall("erf", $LVT, ($LVT,), x.data))
38+
39+
@device_override @inline Base.exp(x::$VT) = $VT(@builtin_ccall("exp", $LVT, ($LVT,), x.data))
40+
@device_override @inline Base.exp2(x::$VT) = $VT(@builtin_ccall("exp2", $LVT, ($LVT,), x.data))
41+
@device_override @inline Base.exp10(x::$VT) = $VT(@builtin_ccall("exp10", $LVT, ($LVT,), x.data))
42+
@device_override @inline Base.expm1(x::$VT) = $VT(@builtin_ccall("expm1", $LVT, ($LVT,), x.data))
43+
44+
@device_override @inline Base.abs(x::$VT) = $VT(@builtin_ccall("fabs", $LVT, ($LVT,), x.data))
45+
@device_override @inline Base.floor(x::$VT) = $VT(@builtin_ccall("floor", $LVT, ($LVT,), x.data))
46+
47+
@device_override @inline SpecialFunctions.loggamma(x::$VT) = $VT(@builtin_ccall("lgamma", $LVT, ($LVT,), x.data))
48+
49+
@device_override @inline Base.log(x::$VT) = $VT(@builtin_ccall("log", $LVT, ($LVT,), x.data))
50+
@device_override @inline Base.log2(x::$VT) = $VT(@builtin_ccall("log2", $LVT, ($LVT,), x.data))
51+
@device_override @inline Base.log10(x::$VT) = $VT(@builtin_ccall("log10", $LVT, ($LVT,), x.data))
52+
@device_override @inline Base.log1p(x::$VT) = $VT(@builtin_ccall("log1p", $LVT, ($LVT,), x.data))
53+
@device_function @inline SPIRVIntrinsics.logb(x::$VT) = $VT(@builtin_ccall("logb", $LVT, ($LVT,), x.data))
54+
55+
@device_function @inline SPIRVIntrinsics.rint(x::$VT) = $VT(@builtin_ccall("rint", $LVT, ($LVT,), x.data))
56+
@device_override @inline Base.round(x::$VT) = $VT(@builtin_ccall("round", $LVT, ($LVT,), x.data))
57+
@device_function @inline SPIRVIntrinsics.rsqrt(x::$VT) = $VT(@builtin_ccall("rsqrt", $LVT, ($LVT,), x.data))
58+
59+
@device_override @inline Base.sin(x::$VT) = $VT(@builtin_ccall("sin", $LVT, ($LVT,), x.data))
60+
@device_override @inline Base.sinh(x::$VT) = $VT(@builtin_ccall("sinh", $LVT, ($LVT,), x.data))
61+
@device_override @inline Base.sinpi(x::$VT) = $VT(@builtin_ccall("sinpi", $LVT, ($LVT,), x.data))
62+
63+
@device_override @inline Base.sqrt(x::$VT) = $VT(@builtin_ccall("sqrt", $LVT, ($LVT,), x.data))
64+
65+
@device_override @inline Base.tan(x::$VT) = $VT(@builtin_ccall("tan", $LVT, ($LVT,), x.data))
66+
@device_override @inline Base.tanh(x::$VT) = $VT(@builtin_ccall("tanh", $LVT, ($LVT,), x.data))
67+
@device_override @inline Base.tanpi(x::$VT) = $VT(@builtin_ccall("tanpi", $LVT, ($LVT,), x.data))
68+
69+
@device_override @inline SpecialFunctions.gamma(x::$VT) = $VT(@builtin_ccall("tgamma", $LVT, ($LVT,), x.data))
70+
71+
@device_override @inline Base.trunc(x::$VT) = $VT(@builtin_ccall("trunc", $LVT, ($LVT,), x.data))
72+
73+
# Binary operations
74+
@device_override @inline Base.atan(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2", $LVT, ($LVT, $LVT), y.data, x.data))
75+
@device_function @inline SPIRVIntrinsics.atanpi(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2pi", $LVT, ($LVT, $LVT), y.data, x.data))
76+
77+
@device_override @inline Base.copysign(x::$VT, y::$VT) = $VT(@builtin_ccall("copysign", $LVT, ($LVT, $LVT), x.data, y.data))
78+
@device_function @inline SPIRVIntrinsics.dim(x::$VT, y::$VT) = $VT(@builtin_ccall("fdim", $LVT, ($LVT, $LVT), x.data, y.data))
79+
80+
@device_override @inline Base.hypot(x::$VT, y::$VT) = $VT(@builtin_ccall("hypot", $LVT, ($LVT, $LVT), x.data, y.data))
81+
82+
@device_override @inline Base.max(x::$VT, y::$VT) = $VT(@builtin_ccall("fmax", $LVT, ($LVT, $LVT), x.data, y.data))
83+
@device_override @inline Base.min(x::$VT, y::$VT) = $VT(@builtin_ccall("fmin", $LVT, ($LVT, $LVT), x.data, y.data))
84+
85+
@device_function @inline SPIRVIntrinsics.maxmag(x::$VT, y::$VT) = $VT(@builtin_ccall("maxmag", $LVT, ($LVT, $LVT), x.data, y.data))
86+
@device_function @inline SPIRVIntrinsics.minmag(x::$VT, y::$VT) = $VT(@builtin_ccall("minmag", $LVT, ($LVT, $LVT), x.data, y.data))
87+
88+
@device_function @inline SPIRVIntrinsics.nextafter(x::$VT, y::$VT) = $VT(@builtin_ccall("nextafter", $LVT, ($LVT, $LVT), x.data, y.data))
89+
90+
@device_override @inline Base.:(^)(x::$VT, y::$VT) = $VT(@builtin_ccall("pow", $LVT, ($LVT, $LVT), x.data, y.data))
91+
@device_function @inline SPIRVIntrinsics.powr(x::$VT, y::$VT) = $VT(@builtin_ccall("powr", $LVT, ($LVT, $LVT), x.data, y.data))
92+
93+
@device_override @inline Base.rem(x::$VT, y::$VT) = $VT(@builtin_ccall("remainder", $LVT, ($LVT, $LVT), x.data, y.data))
94+
95+
# Ternary operations
96+
@device_override @inline Base.fma(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("fma", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
97+
@device_function @inline SPIRVIntrinsics.mad(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("mad", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
98+
end
99+
100+
# Special operations with Int32 parameters
101+
VIntT = :(Vec{$N,Int32})
102+
LVIntT = :(SIMD.LVec{$N,Int32})
103+
104+
@eval begin
105+
@device_function @inline SPIRVIntrinsics.ilogb(x::$VT) = $VIntT(@builtin_ccall("ilogb", $LVIntT, ($LVT,), x.data))
106+
@device_override @inline Base.ldexp(x::$VT, k::$VIntT) = $VT(@builtin_ccall("ldexp", $LVT, ($LVT, $LVIntT), x.data, k.data))
107+
@device_override @inline Base.:(^)(x::$VT, y::$VIntT) = $VT(@builtin_ccall("pown", $LVT, ($LVT, $LVIntT), x.data, y.data))
108+
@device_function @inline SPIRVIntrinsics.rootn(x::$VT, y::$VIntT) = $VT(@builtin_ccall("rootn", $LVT, ($LVT, $LVIntT), x.data, y.data))
109+
end
110+
end
111+
112+
# nan functions - take unsigned integer codes and return floats
113+
for N in [2, 3, 4, 8, 16]
114+
@eval begin
115+
@device_function @inline SPIRVIntrinsics.nan(nancode::Vec{$N,UInt16}) = Vec{$N,Float16}(@builtin_ccall("nan", SIMD.LVec{$N,Float16}, (SIMD.LVec{$N,UInt16},), nancode.data))
116+
@device_function @inline SPIRVIntrinsics.nan(nancode::Vec{$N,UInt32}) = Vec{$N,Float32}(@builtin_ccall("nan", SIMD.LVec{$N,Float32}, (SIMD.LVec{$N,UInt32},), nancode.data))
117+
@device_function @inline SPIRVIntrinsics.nan(nancode::Vec{$N,UInt64}) = Vec{$N,Float64}(@builtin_ccall("nan", SIMD.LVec{$N,Float64}, (SIMD.LVec{$N,UInt64},), nancode.data))
118+
end
119+
end
120+
121+
end # module

lib/intrinsics/src/utils.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,18 @@ macro builtin_ccall(name, ret, argtypes, args...)
3939
error("Unknown type $T")
4040
end
4141
end
42+
mangle(::Type{NTuple{N, VecElement{T}}}) where {N, T} = "Dv$(N)_" * mangle(T)
4243

4344
# C++-style mangling; very limited to just support these intrinsics
4445
# TODO: generalize for use with other intrinsics? do we need to mangle those?
4546
mangled = "_Z$(length(name))$name"
4647
for t in argtypes
4748
# with `@eval @builtin_ccall`, we get actual types in the ast, otherwise symbols
48-
t = (isa(t, Symbol) || isa(t, Expr)) ? eval(t) : t
49+
t = (isa(t, Symbol) || isa(t, Expr)) ? __module__.eval(t) : t
4950
mangled *= mangle(t)
5051
end
5152

52-
push!(known_intrinsics, mangled)
53+
push!(__module__.known_intrinsics, mangled)
5354
esc(quote
5455
@typed_ccall($mangled, llvmcall, $ret, ($(argtypes...),), $(args...))
5556
end)
@@ -63,7 +64,7 @@ Base.Experimental.@MethodTable(method_table)
6364

6465
macro device_override(ex)
6566
esc(quote
66-
Base.Experimental.@overlay(method_table, $ex)
67+
Base.Experimental.@overlay($method_table, $ex)
6768
end)
6869
end
6970

src/compiler/compilation.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1616
Tuple{CompilerJob{SPIRVCompilerTarget}, typeof(fn)},
1717
job, fn) ||
1818
in(fn, known_intrinsics) ||
19+
let SPIRVIntrinsicsSIMDExt = Base.get_extension(SPIRVIntrinsics, :SPIRVIntrinsicsSIMDExt)
20+
SPIRVIntrinsicsSIMDExt !== nothing && in(fn, SPIRVIntrinsicsSIMDExt.known_intrinsics)
21+
end ||
1922
contains(fn, "__spirv_")
2023

2124
GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1212
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1313
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
15+
SIMD = "fdea26ae-647d-5447-a871-4b548cad5224"
1516
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
1617
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
1718
SPIRV_LLVM_Translator_jll = "4a5d46fc-d8cf-5151-a261-86b458210efb"

test/atomics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr
1+
using SPIRVIntrinsics: @builtin_ccall, @typed_ccall, LLVMPtr, known_intrinsics
22

33
@testset "atomics" begin
44

test/intrinsics.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
using SIMD
2+
13
function call_on_device(f, args...)
24
function kernel(res, f, args...)
35
res[] = f(args...)
@@ -11,6 +13,7 @@ end
1113

1214
const float_types = filter(x -> x <: Base.IEEEFloat, GPUArraysTestSuite.supported_eltypes(CLArray))
1315
const ispocl = cl.platform().name == "Portable Computing Language"
16+
const simd_ns = [2, 3, 4, 8, 16]
1417

1518
@testset "intrinsics" begin
1619

@@ -162,6 +165,66 @@ end
162165
@test call_on_device(OpenCL.mad, x, y, z) ≈ x * y + z
163166
end
164167

168+
@testset "SIMD - $N x $T" for N in simd_ns, T in float_types
169+
# codegen emits i48 here, which SPIR-V doesn't support
170+
# XXX: fix upstream?
171+
T == Float16 && N == 3 && continue
172+
173+
v = Vec{N, T}(ntuple(_ -> rand(T), N))
174+
175+
# unary ops: sin, cos, sqrt
176+
a = call_on_device(sin, v)
177+
@test all(a[i] ≈ sin(v[i]) for i in 1:N)
178+
179+
b = call_on_device(cos, v)
180+
@test all(b[i] ≈ cos(v[i]) for i in 1:N)
181+
182+
c = call_on_device(sqrt, v)
183+
@test all(c[i] ≈ sqrt(v[i]) for i in 1:N)
184+
185+
# binary ops: max, hypot
186+
w = Vec{N, T}(ntuple(_ -> rand(T), N))
187+
d = call_on_device(max, v, w)
188+
@test all(d[i] == max(v[i], w[i]) for i in 1:N)
189+
190+
broken = ispocl && T == Float16
191+
if !broken
192+
h = call_on_device(hypot, v, w)
193+
@test all(h[i] ≈ hypot(v[i], w[i]) for i in 1:N)
194+
end
195+
196+
# ternary op: fma
197+
x = Vec{N, T}(ntuple(_ -> rand(T), N))
198+
e = call_on_device(fma, v, w, x)
199+
@test all(e[i] ≈ fma(v[i], w[i], x[i]) for i in 1:N)
200+
201+
# special cases: ilogb, ldexp, ^ with Int32, rootn
202+
v_pos = Vec{N, T}(ntuple(_ -> rand(T) + T(1), N))
203+
@test call_on_device(OpenCL.ilogb, v_pos) isa Vec{N, Int32} broken = broken
204+
205+
k = Vec{N, Int32}(ntuple(_ -> rand(Int32.(-5:5)), N))
206+
@test let
207+
ldexp_result = call_on_device(ldexp, v_pos, k)
208+
all(ldexp_result[i] ≈ ldexp(v_pos[i], k[i]) for i in 1:N)
209+
end broken = broken
210+
211+
base = Vec{N, T}(ntuple(_ -> rand(T) + T(0.5), N))
212+
exp_int = Vec{N, Int32}(ntuple(_ -> rand(Int32.(0:3)), N))
213+
@test let
214+
pow_result = call_on_device(^, base, exp_int)
215+
all(pow_result[i] ≈ base[i] ^ exp_int[i] for i in 1:N)
216+
end broken = broken
217+
218+
rootn_base = Vec{N, T}(ntuple(_ -> rand(T) * T(10) + T(1), N))
219+
rootn_n = Vec{N, Int32}(ntuple(_ -> rand(Int32.(2:4)), N))
220+
@test call_on_device(OpenCL.rootn, rootn_base, rootn_n) isa Vec{N, T} broken = broken
221+
222+
# special cases: nan
223+
nan_code = Vec{N, Base.uinttype(T)}(ntuple(_ -> rand(Base.uinttype(T)), N))
224+
nan_result = call_on_device(OpenCL.nan, nan_code)
225+
@test all(isnan(nan_result[i]) for i in 1:N)
226+
end
227+
165228
end
166229

167230
end

0 commit comments

Comments
 (0)