|
# Extension module that adds `SIMD.Vec` methods for the floating-point math
# intrinsics of SPIRVIntrinsics.
#
# For every vector width `N` in (2, 3, 4, 8, 16) and element type `T` in
# (Float16, Float32, Float64), each method below
#   1. unwraps the `Vec{N,T}` arguments to their `.data` field
#      (typed as `SIMD.LVec{N,T}` in the call signature),
#   2. invokes the named builtin through `@builtin_ccall`, and
#   3. wraps the result back into a `Vec`.
#
# `@device_override` is used for methods of functions owned by `Base` or
# `SpecialFunctions`; `@device_function` is used for functions owned by
# `SPIRVIntrinsics` itself.
module SPIRVIntrinsicsSIMDExt

using SPIRVIntrinsics
using SPIRVIntrinsics: @device_override, @device_function, @builtin_ccall, @typed_ccall
using SIMD
import SpecialFunctions

# Not referenced directly in this file.
# NOTE(review): presumably required/populated by the SPIRVIntrinsics macros
# expanded in this module -- confirm before removing.
const known_intrinsics = String[]

# Generate vectorized math intrinsics
for N in [2, 3, 4, 8, 16], T in [Float16, Float32, Float64]
    # Type expressions spliced into the generated method definitions.
    VT = :(Vec{$N,$T})          # user-facing SIMD vector type
    LVT = :(SIMD.LVec{$N,$T})   # type of `Vec.data`, as passed to the builtin

    @eval begin
        # Unary operations
        @device_override @inline Base.acos(x::$VT) = $VT(@builtin_ccall("acos", $LVT, ($LVT,), x.data))
        @device_override @inline Base.acosh(x::$VT) = $VT(@builtin_ccall("acosh", $LVT, ($LVT,), x.data))
        @device_function @inline SPIRVIntrinsics.acospi(x::$VT) = $VT(@builtin_ccall("acospi", $LVT, ($LVT,), x.data))

        @device_override @inline Base.asin(x::$VT) = $VT(@builtin_ccall("asin", $LVT, ($LVT,), x.data))
        @device_override @inline Base.asinh(x::$VT) = $VT(@builtin_ccall("asinh", $LVT, ($LVT,), x.data))
        @device_function @inline SPIRVIntrinsics.asinpi(x::$VT) = $VT(@builtin_ccall("asinpi", $LVT, ($LVT,), x.data))

        @device_override @inline Base.atan(x::$VT) = $VT(@builtin_ccall("atan", $LVT, ($LVT,), x.data))
        @device_override @inline Base.atanh(x::$VT) = $VT(@builtin_ccall("atanh", $LVT, ($LVT,), x.data))
        @device_function @inline SPIRVIntrinsics.atanpi(x::$VT) = $VT(@builtin_ccall("atanpi", $LVT, ($LVT,), x.data))

        @device_override @inline Base.cbrt(x::$VT) = $VT(@builtin_ccall("cbrt", $LVT, ($LVT,), x.data))
        @device_override @inline Base.ceil(x::$VT) = $VT(@builtin_ccall("ceil", $LVT, ($LVT,), x.data))

        @device_override @inline Base.cos(x::$VT) = $VT(@builtin_ccall("cos", $LVT, ($LVT,), x.data))
        @device_override @inline Base.cosh(x::$VT) = $VT(@builtin_ccall("cosh", $LVT, ($LVT,), x.data))
        @device_override @inline Base.cospi(x::$VT) = $VT(@builtin_ccall("cospi", $LVT, ($LVT,), x.data))

        @device_override @inline SpecialFunctions.erfc(x::$VT) = $VT(@builtin_ccall("erfc", $LVT, ($LVT,), x.data))
        @device_override @inline SpecialFunctions.erf(x::$VT) = $VT(@builtin_ccall("erf", $LVT, ($LVT,), x.data))

        @device_override @inline Base.exp(x::$VT) = $VT(@builtin_ccall("exp", $LVT, ($LVT,), x.data))
        @device_override @inline Base.exp2(x::$VT) = $VT(@builtin_ccall("exp2", $LVT, ($LVT,), x.data))
        @device_override @inline Base.exp10(x::$VT) = $VT(@builtin_ccall("exp10", $LVT, ($LVT,), x.data))
        @device_override @inline Base.expm1(x::$VT) = $VT(@builtin_ccall("expm1", $LVT, ($LVT,), x.data))

        @device_override @inline Base.abs(x::$VT) = $VT(@builtin_ccall("fabs", $LVT, ($LVT,), x.data))
        @device_override @inline Base.floor(x::$VT) = $VT(@builtin_ccall("floor", $LVT, ($LVT,), x.data))

        @device_override @inline SpecialFunctions.loggamma(x::$VT) = $VT(@builtin_ccall("lgamma", $LVT, ($LVT,), x.data))

        @device_override @inline Base.log(x::$VT) = $VT(@builtin_ccall("log", $LVT, ($LVT,), x.data))
        @device_override @inline Base.log2(x::$VT) = $VT(@builtin_ccall("log2", $LVT, ($LVT,), x.data))
        @device_override @inline Base.log10(x::$VT) = $VT(@builtin_ccall("log10", $LVT, ($LVT,), x.data))
        @device_override @inline Base.log1p(x::$VT) = $VT(@builtin_ccall("log1p", $LVT, ($LVT,), x.data))
        @device_function @inline SPIRVIntrinsics.logb(x::$VT) = $VT(@builtin_ccall("logb", $LVT, ($LVT,), x.data))

        @device_function @inline SPIRVIntrinsics.rint(x::$VT) = $VT(@builtin_ccall("rint", $LVT, ($LVT,), x.data))
        # NOTE(review): per the OpenCL C specification the `round` builtin rounds
        # halfway cases away from zero, while `Base.round` defaults to ties-to-even
        # (which `rint` implements) -- confirm this divergence is intended.
        @device_override @inline Base.round(x::$VT) = $VT(@builtin_ccall("round", $LVT, ($LVT,), x.data))
        @device_function @inline SPIRVIntrinsics.rsqrt(x::$VT) = $VT(@builtin_ccall("rsqrt", $LVT, ($LVT,), x.data))

        @device_override @inline Base.sin(x::$VT) = $VT(@builtin_ccall("sin", $LVT, ($LVT,), x.data))
        @device_override @inline Base.sinh(x::$VT) = $VT(@builtin_ccall("sinh", $LVT, ($LVT,), x.data))
        @device_override @inline Base.sinpi(x::$VT) = $VT(@builtin_ccall("sinpi", $LVT, ($LVT,), x.data))

        @device_override @inline Base.sqrt(x::$VT) = $VT(@builtin_ccall("sqrt", $LVT, ($LVT,), x.data))

        @device_override @inline Base.tan(x::$VT) = $VT(@builtin_ccall("tan", $LVT, ($LVT,), x.data))
        @device_override @inline Base.tanh(x::$VT) = $VT(@builtin_ccall("tanh", $LVT, ($LVT,), x.data))
        @device_override @inline Base.tanpi(x::$VT) = $VT(@builtin_ccall("tanpi", $LVT, ($LVT,), x.data))

        @device_override @inline SpecialFunctions.gamma(x::$VT) = $VT(@builtin_ccall("tgamma", $LVT, ($LVT,), x.data))

        @device_override @inline Base.trunc(x::$VT) = $VT(@builtin_ccall("trunc", $LVT, ($LVT,), x.data))

        # Binary operations
        # Two-argument `atan(y, x)` maps to the `atan2` family of builtins.
        @device_override @inline Base.atan(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2", $LVT, ($LVT, $LVT), y.data, x.data))
        @device_function @inline SPIRVIntrinsics.atanpi(y::$VT, x::$VT) = $VT(@builtin_ccall("atan2pi", $LVT, ($LVT, $LVT), y.data, x.data))

        @device_override @inline Base.copysign(x::$VT, y::$VT) = $VT(@builtin_ccall("copysign", $LVT, ($LVT, $LVT), x.data, y.data))
        @device_function @inline SPIRVIntrinsics.dim(x::$VT, y::$VT) = $VT(@builtin_ccall("fdim", $LVT, ($LVT, $LVT), x.data, y.data))

        @device_override @inline Base.hypot(x::$VT, y::$VT) = $VT(@builtin_ccall("hypot", $LVT, ($LVT, $LVT), x.data, y.data))

        # NOTE(review): `fmax`/`fmin` return the non-NaN operand when exactly one
        # input is NaN, whereas `Base.max`/`Base.min` propagate NaN -- confirm
        # this is an accepted trade-off.
        @device_override @inline Base.max(x::$VT, y::$VT) = $VT(@builtin_ccall("fmax", $LVT, ($LVT, $LVT), x.data, y.data))
        @device_override @inline Base.min(x::$VT, y::$VT) = $VT(@builtin_ccall("fmin", $LVT, ($LVT, $LVT), x.data, y.data))

        @device_function @inline SPIRVIntrinsics.maxmag(x::$VT, y::$VT) = $VT(@builtin_ccall("maxmag", $LVT, ($LVT, $LVT), x.data, y.data))
        @device_function @inline SPIRVIntrinsics.minmag(x::$VT, y::$VT) = $VT(@builtin_ccall("minmag", $LVT, ($LVT, $LVT), x.data, y.data))

        @device_function @inline SPIRVIntrinsics.nextafter(x::$VT, y::$VT) = $VT(@builtin_ccall("nextafter", $LVT, ($LVT, $LVT), x.data, y.data))

        @device_override @inline Base.:(^)(x::$VT, y::$VT) = $VT(@builtin_ccall("pow", $LVT, ($LVT, $LVT), x.data, y.data))
        @device_function @inline SPIRVIntrinsics.powr(x::$VT, y::$VT) = $VT(@builtin_ccall("powr", $LVT, ($LVT, $LVT), x.data, y.data))

        # Julia's two-argument `rem` is the truncated remainder (the result takes
        # the sign of `x`), i.e. `x - y * trunc(x / y)`, which is what `fmod`
        # computes. The `remainder` builtin is the IEEE round-to-nearest
        # remainder and would yield e.g. -1.0 instead of 2.0 for (5.0, 3.0).
        @device_override @inline Base.rem(x::$VT, y::$VT) = $VT(@builtin_ccall("fmod", $LVT, ($LVT, $LVT), x.data, y.data))

        # Ternary operations
        @device_override @inline Base.fma(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("fma", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
        @device_function @inline SPIRVIntrinsics.mad(a::$VT, b::$VT, c::$VT) = $VT(@builtin_ccall("mad", $LVT, ($LVT, $LVT, $LVT), a.data, b.data, c.data))
    end

    # Special operations with Int32 parameters: these take or return an Int32
    # vector of the same width `N` as the floating-point vector.
    VIntT = :(Vec{$N,Int32})
    LVIntT = :(SIMD.LVec{$N,Int32})

    @eval begin
        @device_function @inline SPIRVIntrinsics.ilogb(x::$VT) = $VIntT(@builtin_ccall("ilogb", $LVIntT, ($LVT,), x.data))
        @device_override @inline Base.ldexp(x::$VT, k::$VIntT) = $VT(@builtin_ccall("ldexp", $LVT, ($LVT, $LVIntT), x.data, k.data))
        # Integer-vector exponents dispatch to `pown` rather than `pow`.
        @device_override @inline Base.:(^)(x::$VT, y::$VIntT) = $VT(@builtin_ccall("pown", $LVT, ($LVT, $LVIntT), x.data, y.data))
        @device_function @inline SPIRVIntrinsics.rootn(x::$VT, y::$VIntT) = $VT(@builtin_ccall("rootn", $LVT, ($LVT, $LVIntT), x.data, y.data))
    end
end

# nan functions - take unsigned integer codes and return floats.
# The unsigned element width selects the float width of the result:
# UInt16 -> Float16, UInt32 -> Float32, UInt64 -> Float64.
for N in [2, 3, 4, 8, 16]
    @eval begin
        @device_function @inline SPIRVIntrinsics.nan(nancode::Vec{$N,UInt16}) = Vec{$N,Float16}(@builtin_ccall("nan", SIMD.LVec{$N,Float16}, (SIMD.LVec{$N,UInt16},), nancode.data))
        @device_function @inline SPIRVIntrinsics.nan(nancode::Vec{$N,UInt32}) = Vec{$N,Float32}(@builtin_ccall("nan", SIMD.LVec{$N,Float32}, (SIMD.LVec{$N,UInt32},), nancode.data))
        @device_function @inline SPIRVIntrinsics.nan(nancode::Vec{$N,UInt64}) = Vec{$N,Float64}(@builtin_ccall("nan", SIMD.LVec{$N,Float64}, (SIMD.LVec{$N,UInt64},), nancode.data))
    end
end

end # module
0 commit comments