diff --git a/ynnpack/base/arithmetic.h b/ynnpack/base/arithmetic.h index 16bcdbbfbec..6c411ca7c00 100644 --- a/ynnpack/base/arithmetic.h +++ b/ynnpack/base/arithmetic.h @@ -233,6 +233,11 @@ T floor_log2(T a) { return static_cast(exp - 1); } +template +T exp2_round(T a) { + return std::ldexp(static_cast(1.0), static_cast(std::nearbyint(a))); +} + } // namespace ynn #endif // XNNPACK_YNNPACK_BASE_ARITHMETIC_H_ diff --git a/ynnpack/base/simd/arm_neon_base.h b/ynnpack/base/simd/arm_neon_base.h index 31041abb837..30d100a03cd 100644 --- a/ynnpack/base/simd/arm_neon_base.h +++ b/ynnpack/base/simd/arm_neon_base.h @@ -717,6 +717,21 @@ YNN_ALWAYS_INLINE f32x4 floor_log2(f32x4 a) { bias_383); return f32x4{vbslq_f32(is_inf, infinity, res)}; } +YNN_ALWAYS_INLINE f32x4 exp2_round(f32x4 a) { +#if defined(__ARM_ARCH) && __ARM_ARCH < 8 + const float result[] = { + ynn::exp2_round(vgetq_lane_f32(a.v, 0)), + ynn::exp2_round(vgetq_lane_f32(a.v, 1)), + ynn::exp2_round(vgetq_lane_f32(a.v, 2)), + ynn::exp2_round(vgetq_lane_f32(a.v, 3)), + }; + return f32x4{vld1q_f32(result)}; +#else + float32x4_t magic = vdupq_n_f32(127.0f + static_cast(1 << 23)); + int32x4_t res_bits = vreinterpretq_s32_f32(vaddq_f32(a.v, magic)); + return f32x4{vreinterpretq_f32_s32(vshlq_n_s32(res_bits, 23))}; +#endif +} #ifdef YNN_ARCH_ARM64 YNN_ALWAYS_INLINE f64x2 floor_log2(f64x2 a) { const double result[] = { @@ -725,6 +740,11 @@ YNN_ALWAYS_INLINE f64x2 floor_log2(f64x2 a) { }; return f64x2{vld1q_f64(result)}; } +YNN_ALWAYS_INLINE f64x2 exp2_round(f64x2 a) { + float64x2_t magic = vdupq_n_f64(1023.0 + static_cast(1ll << 52)); + int64x2_t res_bits = vreinterpretq_s64_f64(vaddq_f64(a.v, magic)); + return f64x2{vreinterpretq_f64_s64(vshlq_n_s64(res_bits, 52))}; +} #endif namespace internal { @@ -750,7 +770,7 @@ YNN_ALWAYS_INLINE float32x4_t not_f32(float32x4_t a) { YNN_ALWAYS_INLINE f32x4 floor(f32x4 a) { #if defined(__ARM_ARCH) && __ARM_ARCH < 8 - float32x4_t max_non_int_val = vdupq_n_f32(8388608.0f); + float32x4_t max_non_int_val = vdupq_n_f32(static_cast(1 << 23)); uint32x4_t use_rounding = vcaltq_f32(a.v, max_non_int_val); float32x4_t trunc = vcvtq_f32_s32(vcvtq_s32_f32(a.v)); uint32x4_t floor_mask = vcgtq_f32(trunc, a.v); @@ -768,7 +788,7 @@ YNN_ALWAYS_INLINE f64x2 floor(f64x2 a) { return f64x2{vrndmq_f64(a.v)}; } YNN_ALWAYS_INLINE f32x4 ceil(f32x4 a) { #if defined(__ARM_ARCH) && __ARM_ARCH < 8 - float32x4_t max_non_int_val = vdupq_n_f32(8388608.0f); + float32x4_t max_non_int_val = vdupq_n_f32(static_cast(1 << 23)); uint32x4_t use_rounding = vcaltq_f32(a.v, max_non_int_val); float32x4_t trunc = vcvtq_f32_s32(vcvtq_s32_f32(a.v)); uint32x4_t ceil_mask = vcltq_f32(trunc, a.v); @@ -786,7 +806,7 @@ YNN_ALWAYS_INLINE f64x2 ceil(f64x2 a) { return f64x2{vrndpq_f64(a.v)}; } YNN_ALWAYS_INLINE f32x4 round(f32x4 a) { #if defined(__ARM_ARCH) && __ARM_ARCH < 8 - float32x4_t max_non_int_val = vdupq_n_f32(8388608.0f); + float32x4_t max_non_int_val = vdupq_n_f32(static_cast(1 << 23)); float32x4_t filter = vreinterpretq_f32_u32(vcaltq_f32(a.v, max_non_int_val)); float32x4_t half = vdupq_n_f32(0.5f); float32x4_t sign_mask = vdupq_n_f32(-0.0f); diff --git a/ynnpack/base/simd/bench/arm64_neon.cc b/ynnpack/base/simd/bench/arm64_neon.cc index 649303a6cd5..399211e8fa0 100644 --- a/ynnpack/base/simd/bench/arm64_neon.cc +++ b/ynnpack/base/simd/bench/arm64_neon.cc @@ -13,6 +13,7 @@ namespace ynn { namespace simd { BENCH_UNARY(neon, floor_log2, f64, 2); +BENCH_UNARY(neon, exp2_round, f64, 2); } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/bench/arm_neon.cc b/ynnpack/base/simd/bench/arm_neon.cc index 3c39f2f512f..3f9127eb2ba 100644 --- a/ynnpack/base/simd/bench/arm_neon.cc +++ b/ynnpack/base/simd/bench/arm_neon.cc @@ -18,6 +18,7 @@ BENCH_PARTIAL_LOAD_STORE(neon, s16, 8); BENCH_PARTIAL_LOAD_STORE(neon, s32, 4); BENCH_UNARY(neon, floor_log2, f32, 4); +BENCH_UNARY(neon, exp2_round, f32, 4); } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/bench/generic.h b/ynnpack/base/simd/bench/generic.h index 337638682a2..c53f1fb8b41 100644 --- a/ynnpack/base/simd/bench/generic.h +++ b/ynnpack/base/simd/bench/generic.h @@ -146,10 +146,10 @@ static void BM_unary(benchmark::State& state, Fn fn) { } } -#define BENCH_UNARY(arch, op, type, N) \ - void BM_##op##_##type##x##N##_##arch(benchmark::State& state) { \ - BM_unary(state, [](vec a) { return floor_log2(a); }); \ - } \ +#define BENCH_UNARY(arch, op, type, N) \ + void BM_##op##_##type##x##N##_##arch(benchmark::State& state) { \ + BM_unary(state, [](vec a) { return op(a); }); \ + } \ BENCHMARK(BM_##op##_##type##x##N##_##arch); } // namespace simd diff --git a/ynnpack/base/simd/bench/x86_avx2.cc b/ynnpack/base/simd/bench/x86_avx2.cc index 9b7efcf62ab..6585b73c14f 100644 --- a/ynnpack/base/simd/bench/x86_avx2.cc +++ b/ynnpack/base/simd/bench/x86_avx2.cc @@ -12,6 +12,8 @@ namespace simd { BENCH_UNARY(avx2, floor_log2, f32, 8); BENCH_UNARY(avx2, floor_log2, f64, 4); +BENCH_UNARY(avx2, exp2_round, f32, 8); +BENCH_UNARY(avx2, exp2_round, f64, 4); } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/bench/x86_avx512.cc b/ynnpack/base/simd/bench/x86_avx512.cc index 130f2b9d20a..f247cf3a902 100644 --- a/ynnpack/base/simd/bench/x86_avx512.cc +++ b/ynnpack/base/simd/bench/x86_avx512.cc @@ -31,6 +31,8 @@ BENCH_UNARY(avx512, floor_log2, f32, 4); BENCH_UNARY(avx512, floor_log2, f64, 8); BENCH_UNARY(avx512, floor_log2, f64, 4); BENCH_UNARY(avx512, floor_log2, f64, 2); +BENCH_UNARY(avx512, exp2_round, f32, 16); +BENCH_UNARY(avx512, exp2_round, f64, 8); } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/bench/x86_sse2.cc b/ynnpack/base/simd/bench/x86_sse2.cc index fa32725253b..b367d105781 100644 --- a/ynnpack/base/simd/bench/x86_sse2.cc +++ b/ynnpack/base/simd/bench/x86_sse2.cc @@ -20,6 +20,8 @@ BENCH_FMA(sse2, f32, 4); BENCH_UNARY(sse2, floor_log2, f32, 4); BENCH_UNARY(sse2, floor_log2, f64, 2); +BENCH_UNARY(sse2, exp2_round, f32, 4); +BENCH_UNARY(sse2, exp2_round, f64, 2); } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/generic.inc b/ynnpack/base/simd/generic.inc index 687434497ba..920953816d3 100644 --- a/ynnpack/base/simd/generic.inc +++ b/ynnpack/base/simd/generic.inc @@ -210,6 +210,11 @@ template YNN_ALWAYS_INLINE vec floor_log2(vec a) { return {floor_log2(a.lo()), floor_log2(a.hi())}; } + +template +YNN_ALWAYS_INLINE vec exp2_round(vec a) { + return {exp2_round(a.lo()), exp2_round(a.hi())}; +} template YNN_ALWAYS_INLINE vec ceil(vec a) { return {ceil(a.lo()), ceil(a.hi())}; diff --git a/ynnpack/base/simd/test/arm64_neon.cc b/ynnpack/base/simd/test/arm64_neon.cc index a01621712ce..7b27aa8ba86 100644 --- a/ynnpack/base/simd/test/arm64_neon.cc +++ b/ynnpack/base/simd/test/arm64_neon.cc @@ -39,6 +39,7 @@ TEST_ROUND(arm_neon, f64, 2); TEST_SQRT(arm_neon, f64, 2); TEST_ABS(arm_neon, f64, 2); TEST_FLOOR_LOG2(arm_neon, f64, 2); +TEST_EXP2_ROUND(arm_neon, f64, 2); TEST_HORIZONTAL_MIN(arm_neon, f64, 2); TEST_HORIZONTAL_MAX(arm_neon, f64, 2); diff --git a/ynnpack/base/simd/test/arm_neon.cc b/ynnpack/base/simd/test/arm_neon.cc index d7708db68ae..0ccecf8b212 100644 --- a/ynnpack/base/simd/test/arm_neon.cc +++ b/ynnpack/base/simd/test/arm_neon.cc @@ -144,6 +144,7 @@ TEST_ABS(arm_neon, s32, 4); TEST_ABS(arm_neon, f32, 4); TEST_FLOOR_LOG2(arm_neon, f32, 4); +TEST_EXP2_ROUND(arm_neon, f32, 4); TEST_HORIZONTAL_MIN(arm_neon, u8, 16); TEST_HORIZONTAL_MIN(arm_neon, s8, 16); diff --git a/ynnpack/base/simd/test/generic.h b/ynnpack/base/simd/test/generic.h index 51cfa3365ac..1695f2d322b 100644 --- a/ynnpack/base/simd/test/generic.h +++ b/ynnpack/base/simd/test/generic.h @@ -441,6 +441,32 @@ void test_floor_log2() { #define TEST_FLOOR_LOG2(test_class, type, N) \ TEST_F(test_class, floor_log2_##type##x##N) { test_floor_log2(); } +template +void test_exp2_round() { + using vector = vec; + + ReplicableRandomDevice rng; + for (auto _ : FuzzTest(std::chrono::milliseconds(100))) { + scalar a[vector::N]; + auto abs_max = std::log2(type_info::max()) - 2; + fill_random(a, vector::N, rng, -abs_max, abs_max); + + scalar result[vector::N]; + store(result, exp2_round(load(a, vector::N))); + + for (size_t i = 0; i < vector::N; ++i) { + // Allow any integer k such that |k - a| <= 0.5. + // This accommodates different tie-breaking behaviors. + auto k = std::log2(result[i]); + ASSERT_NEAR(k, a[i], 0.5001) << a[i]; + ASSERT_EQ(k, std::nearbyint(k)) << a[i]; + } + } +} + +#define TEST_EXP2_ROUND(test_class, type, N) \ + TEST_F(test_class, exp2_round_##type##x##N) { test_exp2_round(); } + struct min_op { template T operator()(T a, T b) { diff --git a/ynnpack/base/simd/test/wasm_simd128.cc b/ynnpack/base/simd/test/wasm_simd128.cc index b8904e00979..e20b2c9564a 100644 --- a/ynnpack/base/simd/test/wasm_simd128.cc +++ b/ynnpack/base/simd/test/wasm_simd128.cc @@ -120,6 +120,10 @@ TEST_FLOOR(wasm_simd128, f32, 4); TEST_CEIL(wasm_simd128, f32, 4); TEST_ROUND(wasm_simd128, f32, 4); TEST_SQRT(wasm_simd128, f32, 4); +TEST_FLOOR_LOG2(wasm_simd128, f32, 4); +TEST_FLOOR_LOG2(wasm_simd128, f64, 2); +TEST_EXP2_ROUND(wasm_simd128, f32, 4); +TEST_EXP2_ROUND(wasm_simd128, f64, 2); TEST_ABS(wasm_simd128, s8, 16); TEST_ABS(wasm_simd128, s16, 8); diff --git a/ynnpack/base/simd/test/x86_avx2.cc b/ynnpack/base/simd/test/x86_avx2.cc index 9b6749c097d..c6f4ca7ed34 100644 --- a/ynnpack/base/simd/test/x86_avx2.cc +++ b/ynnpack/base/simd/test/x86_avx2.cc @@ -59,6 +59,8 @@ TEST_ABS(x86_avx2, s32, 8); TEST_FLOOR_LOG2(x86_avx2, f32, 8); TEST_FLOOR_LOG2(x86_avx2, f64, 4); +TEST_EXP2_ROUND(x86_avx2, f32, 8); +TEST_EXP2_ROUND(x86_avx2, f64, 4); TEST_HORIZONTAL_MIN(x86_avx2, u8, 32); TEST_HORIZONTAL_MIN(x86_avx2, s8, 32); diff --git a/ynnpack/base/simd/test/x86_avx512.cc b/ynnpack/base/simd/test/x86_avx512.cc index 7fef12e3fc6..dc061dc0efb 100644 --- a/ynnpack/base/simd/test/x86_avx512.cc +++ b/ynnpack/base/simd/test/x86_avx512.cc @@ -145,6 +145,8 @@ TEST_FLOOR_LOG2(x86_avx512, f32, 4); TEST_FLOOR_LOG2(x86_avx512, f64, 8); TEST_FLOOR_LOG2(x86_avx512, f64, 4); TEST_FLOOR_LOG2(x86_avx512, f64, 2); +TEST_EXP2_ROUND(x86_avx512, f32, 16); +TEST_EXP2_ROUND(x86_avx512, f64, 8); TEST_FLOOR(x86_avx512, f32, 16); TEST_FLOOR(x86_avx512, f64, 8); diff --git a/ynnpack/base/simd/test/x86_sse2.cc b/ynnpack/base/simd/test/x86_sse2.cc index 5165c03ba86..0e5fa4aa8f4 100644 --- a/ynnpack/base/simd/test/x86_sse2.cc +++ b/ynnpack/base/simd/test/x86_sse2.cc @@ -116,6 +116,8 @@ TEST_ABS(x86_sse2, f32, 4); TEST_FLOOR_LOG2(x86_sse2, f32, 4); TEST_FLOOR_LOG2(x86_sse2, f64, 2); +TEST_EXP2_ROUND(x86_sse2, f32, 4); +TEST_EXP2_ROUND(x86_sse2, f64, 2); TEST_HORIZONTAL_MIN(x86_sse2, u8, 16); TEST_HORIZONTAL_MIN(x86_sse2, s16, 8); diff --git a/ynnpack/base/simd/vec.h b/ynnpack/base/simd/vec.h index a17749d5bb2..4efa41b1606 100644 --- a/ynnpack/base/simd/vec.h +++ b/ynnpack/base/simd/vec.h @@ -130,6 +130,9 @@ template vec floor(vec a); template vec floor_log2(vec a); +// This function is permitted to use any rounding mode. +template +vec exp2_round(vec a); template vec ceil(vec a); template @@ -317,7 +320,11 @@ YNN_ALWAYS_INLINE vec sub_sat(vec a, vec b) { } template YNN_ALWAYS_INLINE vec floor_log2(vec a) { - return vec{floor_log2(a.v)}; + return vec{ynn::floor_log2(a.v)}; +} +template +YNN_ALWAYS_INLINE vec exp2_round(vec a) { + return vec{ynn::exp2_round(a.v)}; } template diff --git a/ynnpack/base/simd/wasm_simd128.h b/ynnpack/base/simd/wasm_simd128.h index 3bf66d5709e..b607fbb5128 100644 --- a/ynnpack/base/simd/wasm_simd128.h +++ b/ynnpack/base/simd/wasm_simd128.h @@ -524,6 +524,44 @@ YNN_ALWAYS_INLINE f32x4 round(f32x4 a) { } YNN_ALWAYS_INLINE f32x4 sqrt(f32x4 a) { return f32x4{wasm_f32x4_sqrt(a.v)}; } +YNN_ALWAYS_INLINE f32x4 floor_log2(f32x4 a) { + const v128_t sign_mask = wasm_f32x4_splat(-0.0f); + const v128_t is_zero = wasm_f32x4_eq(a.v, wasm_f32x4_splat(0.0f)); + a.v = wasm_v128_or(wasm_v128_and(is_zero, sign_mask), a.v); + + const v128_t sign_and_exp_mask = wasm_i32x4_splat(0xFF800000); + v128_t exp = wasm_v128_and(a.v, sign_and_exp_mask); + + const v128_t infinity = + wasm_f32x4_splat(std::numeric_limits::infinity()); + const v128_t is_inf = wasm_f32x4_eq(a.v, infinity); + + exp = wasm_i32x4_shr(exp, 8); + + const v128_t bias_256 = wasm_f32x4_splat(256.0f); + const v128_t bias_383 = wasm_f32x4_splat(383.0f); + const v128_t res = wasm_f32x4_sub(wasm_v128_or(bias_256, exp), bias_383); + return f32x4{wasm_v128_bitselect(infinity, res, is_inf)}; +} + +YNN_ALWAYS_INLINE f64x2 floor_log2(f64x2 a) { + return f64x2{wasm_f64x2_make( + ynn::floor_log2(wasm_f64x2_extract_lane(a.v, 0)), + ynn::floor_log2(wasm_f64x2_extract_lane(a.v, 1)))}; +} + +YNN_ALWAYS_INLINE f32x4 exp2_round(f32x4 a) { + const v128_t magic = wasm_f32x4_splat(127.0f + static_cast(1 << 23)); + const v128_t res_bits = wasm_f32x4_add(a.v, magic); + return f32x4{wasm_i32x4_shl(res_bits, 23)}; +} + +YNN_ALWAYS_INLINE f64x2 exp2_round(f64x2 a) { + return f64x2{wasm_f64x2_make( + ynn::exp2_round(wasm_f64x2_extract_lane(a.v, 0)), + ynn::exp2_round(wasm_f64x2_extract_lane(a.v, 1)))}; +} + YNN_ALWAYS_INLINE s16x16 cast(s8x16 a, int16_t) { return {s16x8{wasm_i16x8_extend_low_i8x16(a.v)}, s16x8{wasm_i16x8_extend_high_i8x16(a.v)}}; diff --git a/ynnpack/base/simd/x86_avx2_base.h b/ynnpack/base/simd/x86_avx2_base.h index 59c597f446b..c810cc9bd14 100644 --- a/ynnpack/base/simd/x86_avx2_base.h +++ b/ynnpack/base/simd/x86_avx2_base.h @@ -165,6 +165,19 @@ YNN_ALWAYS_INLINE f32x8 cast(bf16x8 a, float) { 16))}; } +YNN_ALWAYS_INLINE f32x8 exp2_round(f32x8 a) { + const __m256 magic = _mm256_set1_ps(127.0f + static_cast(1 << 23)); + const __m256 res_bits = _mm256_add_ps(a.v, magic); + return f32x8{_mm256_castsi256_ps( + _mm256_slli_epi32(_mm256_castps_si256(res_bits), 23))}; +} +YNN_ALWAYS_INLINE f64x4 exp2_round(f64x4 a) { + const __m256d magic = _mm256_set1_pd(1023.0 + static_cast(1ll << 52)); + const __m256d res_bits = _mm256_add_pd(a.v, magic); + return f64x4{_mm256_castsi256_pd( + _mm256_slli_epi64(_mm256_castpd_si256(res_bits), 52))}; +} + } // namespace simd } // namespace ynn diff --git a/ynnpack/base/simd/x86_avx512.h b/ynnpack/base/simd/x86_avx512.h index c4c5f88d0ae..debe80c035c 100644 --- a/ynnpack/base/simd/x86_avx512.h +++ b/ynnpack/base/simd/x86_avx512.h @@ -956,6 +956,19 @@ YNN_ALWAYS_INLINE f64x8 floor_log2(f64x8 a) { negative, res, _mm512_set1_pd(std::numeric_limits::quiet_NaN()))}; } +YNN_ALWAYS_INLINE f32x16 exp2_round(f32x16 a) { + const __m512 magic = _mm512_set1_ps(127.0f + static_cast(1 << 23)); + const __m512 res_bits = _mm512_add_ps(a.v, magic); + return f32x16{_mm512_castsi512_ps( + _mm512_slli_epi32(_mm512_castps_si512(res_bits), 23))}; +} +YNN_ALWAYS_INLINE f64x8 exp2_round(f64x8 a) { + const __m512d magic = _mm512_set1_pd(1023.0 + static_cast(1ll << 52)); + const __m512d res_bits = _mm512_add_pd(a.v, magic); + return f64x8{_mm512_castsi512_pd( + _mm512_slli_epi64(_mm512_castpd_si512(res_bits), 52))}; +} + template YNN_ALWAYS_INLINE s32x4 extract(s32x16 x, decltype(s32x4::N)) { return s32x4{_mm512_extracti32x4_epi32(x.v, Index)}; diff --git a/ynnpack/base/simd/x86_sse2_base.h b/ynnpack/base/simd/x86_sse2_base.h index 7f6aa4fdb55..0e17bebc2aa 100644 --- a/ynnpack/base/simd/x86_sse2_base.h +++ b/ynnpack/base/simd/x86_sse2_base.h @@ -505,6 +505,19 @@ YNN_ALWAYS_INLINE f64x2 abs(f64x2 a) { return f64x2{_mm_andnot_pd(_mm_set1_pd(-0.0), a.v)}; } +YNN_ALWAYS_INLINE f32x4 exp2_round(f32x4 a) { + const __m128 magic = _mm_set1_ps(127.0 + static_cast(1 << 23)); + const __m128 res_bits = _mm_add_ps(a.v, magic); + return f32x4{ + _mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(res_bits), 23))}; +} +YNN_ALWAYS_INLINE f64x2 exp2_round(f64x2 a) { + const __m128d magic = _mm_set1_pd(1023.0 + static_cast(1ll << 52)); + const __m128d res_bits = _mm_add_pd(a.v, magic); + return f64x2{ + _mm_castsi128_pd(_mm_slli_epi64(_mm_castpd_si128(res_bits), 52))}; +} + YNN_ALWAYS_INLINE double horizontal_max(f64x2 a) { return _mm_cvtsd_f64(_mm_max_sd(a.v, _mm_shuffle_pd(a.v, a.v, 1))); } diff --git a/ynnpack/kernels/elementwise/compiler.py b/ynnpack/kernels/elementwise/compiler.py index 73062d4ee25..9a17bbc09e5 100644 --- a/ynnpack/kernels/elementwise/compiler.py +++ b/ynnpack/kernels/elementwise/compiler.py @@ -433,6 +433,11 @@ def floor_log2(value): return Op(value.ty, "floor_log2", [value]) +@intrinsic +def exp2_round(value): + return Op(value.ty, "exp2_round", [value]) + + @intrinsic def sqrt(value): return Op(value.ty, "sqrt", [value]) @@ -976,6 +981,7 @@ def __init__(self): "round", "floor", "floor_log2", + "exp2_round", "ceil", "sqrt", "reinterpret_cast", diff --git a/ynnpack/kernels/unary/exp.py b/ynnpack/kernels/unary/exp.py index ff5fc651081..de65590c402 100644 --- a/ynnpack/kernels/unary/exp.py +++ b/ynnpack/kernels/unary/exp.py @@ -5,25 +5,11 @@ from ynnpack.kernels.elementwise.compiler import * # pylint: disable=wildcard-import -def setexp_f32(x): - # If `x` is an floating point value in the range [-127, 128], then - # `(x + magic) << 23` will generate the floating point value corresponding - # to `2^round(x)` (2^-127 and 2^128 will flush to zero and infinity, - # respectively). - vmagic = 8388735.0 - return reinterpret_cast( - Float(32), - logical_shift_left(reinterpret_cast(Int(32), x + vmagic), i32(23)), - ) - - -# Quick-and-dirty round to nearest, only works for floats in the range -# `[2^-22, 2^22)`. def qd_round_f32(a): # If `x` is an floating point value in the range `[2^-22, 2^22)`, then # `(x + magic) - magic`` will generate the floating point value corresponding # to `round(x)`. - vmagic = 12582912.0 + vmagic = 1.5*(2**23) return (vmagic + a) - vmagic @@ -54,7 +40,7 @@ def exp_fp32(a, x, output_multiplier, input_multiplier): vr = vz_prime - vz # Compute 2^z. - v2z = setexp_f32(vz) + v2z = exp2_round(vz) # Evaluate the numerator polynomial p(f). vp = multiply_add(vr, valpha_3, valpha_2)