Skip to content

Commit d2403f5

Browse files
authored
add Float16 math functions (#368)
1 parent a6f878a commit d2403f5

File tree

4 files changed

+132
-9
lines changed

4 files changed

+132
-9
lines changed

lib/intrinsics/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SPIRVIntrinsics"
22
uuid = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
33
authors = ["Tim Besard <[email protected]>"]
4-
version = "0.5.2"
4+
version = "0.5.3"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"

lib/intrinsics/src/math.jl

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Math Functions
22

33
# TODO: vector types
4-
const generic_types = [Float32,Float64]
4+
const generic_types = [Float16, Float32, Float64]
55
const generic_types_float = [Float32]
66
const generic_types_double = [Float64]
77

@@ -33,7 +33,7 @@ for gentype in generic_types
3333

3434
@device_override Base.cos(x::$gentype) = @builtin_ccall("cos", $gentype, ($gentype,), x)
3535
@device_override Base.cosh(x::$gentype) = @builtin_ccall("cosh", $gentype, ($gentype,), x)
36-
@device_function cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)
36+
@device_override Base.cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x)
3737

3838
@device_override SpecialFunctions.erfc(x::$gentype) = @builtin_ccall("erfc", $gentype, ($gentype,), x)
3939
@device_override SpecialFunctions.erf(x::$gentype) = @builtin_ccall("erf", $gentype, ($gentype,), x)
@@ -59,7 +59,10 @@ for gentype in generic_types
5959
#@device_override Base.mod(x::$gentype, y::$gentype) = @builtin_ccall("fmod", $gentype, ($gentype, $gentype), x, y)
6060
# fract(x::$gentype, $gentype *iptr) = @builtin_ccall("fract", $gentype, ($gentype, $gentype *), x, iptr)
6161

62-
@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
62+
# TODO: remove once https://github.com/pocl/pocl/issues/2034 is addressed
63+
if $gentype != Float16
64+
@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y)
65+
end
6366

6467
@device_override SpecialFunctions.loggamma(x::$gentype) = @builtin_ccall("lgamma", $gentype, ($gentype,), x)
6568

@@ -81,8 +84,6 @@ for gentype in generic_types
8184
@device_override Base.:(^)(x::$gentype, y::$gentype) = @builtin_ccall("pow", $gentype, ($gentype, $gentype), x, y)
8285
@device_function powr(x::$gentype, y::$gentype) = @builtin_ccall("powr", $gentype, ($gentype, $gentype), x, y)
8386

84-
@device_override Base.rem(x::$gentype, y::$gentype) = @builtin_ccall("remainder", $gentype, ($gentype, $gentype), x, y)
85-
8687
@device_function rint(x::$gentype) = @builtin_ccall("rint", $gentype, ($gentype,), x)
8788

8889
@device_override Base.round(x::$gentype) = @builtin_ccall("round", $gentype, ($gentype,), x)
@@ -100,13 +101,13 @@ for gentype in generic_types
100101
return sinval, cosval[]
101102
end
102103
@device_override Base.sinh(x::$gentype) = @builtin_ccall("sinh", $gentype, ($gentype,), x)
103-
@device_function sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)
104+
@device_override Base.sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x)
104105

105106
@device_override Base.sqrt(x::$gentype) = @builtin_ccall("sqrt", $gentype, ($gentype,), x)
106107

107108
@device_override Base.tan(x::$gentype) = @builtin_ccall("tan", $gentype, ($gentype,), x)
108109
@device_override Base.tanh(x::$gentype) = @builtin_ccall("tanh", $gentype, ($gentype,), x)
109-
@device_function tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)
110+
@device_override Base.tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x)
110111

111112
@device_override SpecialFunctions.gamma(x::$gentype) = @builtin_ccall("tgamma", $gentype, ($gentype,), x)
112113

@@ -151,11 +152,13 @@ end
151152
# frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp)
152153
# frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp)
153154

155+
@device_function ilogb(x::Float16) = @builtin_ccall("ilogb", Int32, (Float16,), x)
154156
# ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x)
155157
@device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x)
156158
# ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x)
157159
@device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x)
158160

161+
@device_override Base.ldexp(x::Float16, k::Int32) = @builtin_ccall("ldexp", Float16, (Float16, Int32), x, k)
159162
# ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k)
160163
# ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k)
161164
@device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k)
@@ -168,11 +171,13 @@ end
168171
# lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp)
169172
# Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp)
170173

174+
@device_function nan(nancode::UInt16) = @builtin_ccall("nan", Float16, (UInt16,), nancode)
171175
# nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode)
172176
@device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode)
173177
# nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode)
174178
@device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode)
175179

180+
@device_override Base.:(^)(x::Float16, y::Int32) = @builtin_ccall("pown", Float16, (Float16, Int32), x, y)
176181
# pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y)
177182
@device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y)
178183
# pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y)
@@ -183,10 +188,11 @@ end
183188
# remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo)
184189
# remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo)
185190

191+
@device_function rootn(x::Float16, y::Int32) = @builtin_ccall("rootn", Float16, (Float16, Int32), x, y)
186192
# rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y)
187193
@device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y)
188194
# rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y)
189-
# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y)
195+
@device_function rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64, (Float64, Int32), x, y)
190196

191197

192198
# TODO: half and native

lib/intrinsics/src/utils.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ macro builtin_ccall(name, ret, argtypes, args...)
2626
"c"
2727
elseif T == UInt8
2828
"h"
29+
elseif T == Float16
30+
"Dh"
2931
elseif T == Float32
3032
"f"
3133
elseif T == Float64

test/intrinsics.jl

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
function call_on_device(f, args...)
2+
function kernel(res, f, args...)
3+
res[] = f(args...)
4+
return
5+
end
6+
T = OpenCL.code_typed(() -> f(args...), ())[][2]
7+
res = CLArray{T, 0}(undef)
8+
@opencl kernel(res, f, args...)
9+
return OpenCL.@allowscalar res[]
10+
end
11+
12+
const float_types = filter(x -> x <: Base.IEEEFloat, GPUArraysTestSuite.supported_eltypes(CLArray))
13+
const ispocl = cl.platform().name == "Portable Computing Language"
14+
115
@testset "intrinsics" begin
216

317
@testset "barrier" begin
@@ -49,4 +63,105 @@ cl.memory_backend() isa cl.SVMBackend && @on_device atomic_work_item_fence(OpenC
4963

5064
end
5165

66+
@testset "math" begin
67+
68+
@testset "unary - $T" for T in float_types
69+
@testset "$f" for f in [
70+
acos, acosh,
71+
asin, asinh,
72+
atan, atanh,
73+
cbrt,
74+
ceil,
75+
cos, cosh, cospi,
76+
exp, exp2, exp10, expm1,
77+
abs,
78+
floor,
79+
log, log2, log10, log1p,
80+
round,
81+
sin, sinh, sinpi,
82+
sqrt,
83+
tan, tanh, tanpi,
84+
trunc,
85+
]
86+
x = rand(T)
87+
if f == acosh
88+
x += 1
89+
end
90+
broken = ispocl && T == Float16 && f in [acosh, asinh, atanh, cbrt, cospi, expm1, log1p, sinpi, tanpi]
91+
@test call_on_device(f, x) f(x) broken = broken
92+
end
93+
end
94+
95+
@testset "binary - $T" for T in float_types
96+
@testset "$f" for f in [
97+
atan,
98+
copysign,
99+
max,
100+
min,
101+
hypot,
102+
(^),
103+
]
104+
x = rand(T)
105+
y = rand(T)
106+
broken = ispocl && T == Float16 && f == atan
107+
@test call_on_device(f, x, y) f(x, y) broken = broken
108+
end
109+
end
110+
111+
@testset "ternary - $T" for T in float_types
112+
@testset "$f" for f in [
113+
fma,
114+
]
115+
x = rand(T)
116+
y = rand(T)
117+
z = rand(T)
118+
@test call_on_device(f, x, y, z) f(x, y, z)
119+
end
120+
end
121+
122+
@testset "OpenCL-specific unary - $T" for T in float_types
123+
@testset "$f" for f in [
124+
OpenCL.acospi,
125+
OpenCL.asinpi,
126+
OpenCL.atanpi,
127+
OpenCL.logb,
128+
OpenCL.rint,
129+
OpenCL.rsqrt,
130+
]
131+
x = rand(T)
132+
broken = ispocl && T == Float16 && !(f in [OpenCL.rint, OpenCL.rsqrt])
133+
@test call_on_device(f, x) isa Real broken = broken # Just check it doesn't error
134+
end
135+
broken = ispocl && T == Float16
136+
@test call_on_device(OpenCL.ilogb, T(8.0)) isa Int32 broken = broken
137+
@test call_on_device(OpenCL.nan, Base.uinttype(T)(0)) isa T
138+
end
139+
140+
@testset "OpenCL-specific binary - $T" for T in float_types
141+
@testset "$f" for f in [
142+
OpenCL.atanpi,
143+
OpenCL.dim,
144+
OpenCL.maxmag,
145+
OpenCL.minmag,
146+
OpenCL.nextafter,
147+
OpenCL.powr,
148+
]
149+
x = rand(T)
150+
y = rand(T)
151+
broken = ispocl && T == Float16 && !(f in [OpenCL.maxmag, OpenCL.minmag])
152+
@test call_on_device(f, x, y) isa Real broken = broken # Just check it doesn't error
153+
end
154+
broken = ispocl && T == Float16
155+
@test call_on_device(OpenCL.rootn, T(8.0), Int32(3)) T(2.0) broken = broken
156+
end
157+
158+
@testset "OpenCL-specific ternary - $T" for T in float_types
159+
x = rand(T)
160+
y = rand(T)
161+
z = rand(T)
162+
@test call_on_device(OpenCL.mad, x, y, z) x * y + z
163+
end
164+
165+
end
166+
52167
end

0 commit comments

Comments
 (0)