Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ynnpack/base/arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,11 @@ T floor_log2(T a) {
return static_cast<T>(exp - 1);
}

template <typename T>
T exp2_round(T a) {
return std::ldexp(static_cast<T>(1.0), static_cast<int>(std::nearbyint(a)));
}

} // namespace ynn

#endif // XNNPACK_YNNPACK_BASE_ARITHMETIC_H_
26 changes: 23 additions & 3 deletions ynnpack/base/simd/arm_neon_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -717,6 +717,21 @@ YNN_ALWAYS_INLINE f32x4 floor_log2(f32x4 a) {
bias_383);
return f32x4{vbslq_f32(is_inf, infinity, res)};
}
YNN_ALWAYS_INLINE f32x4 exp2_round(f32x4 a) {
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
const float result[] = {
ynn::exp2_round(vgetq_lane_f32(a.v, 0)),
ynn::exp2_round(vgetq_lane_f32(a.v, 1)),
ynn::exp2_round(vgetq_lane_f32(a.v, 2)),
ynn::exp2_round(vgetq_lane_f32(a.v, 3)),
};
return f32x4{vld1q_f32(result)};
#else
float32x4_t magic = vdupq_n_f32(127.0f + static_cast<float>(1 << 23));
int32x4_t res_bits = vreinterpretq_s32_f32(vaddq_f32(a.v, magic));
return f32x4{vreinterpretq_f32_s32(vshlq_n_s32(res_bits, 23))};
#endif
}
#ifdef YNN_ARCH_ARM64
YNN_ALWAYS_INLINE f64x2 floor_log2(f64x2 a) {
const double result[] = {
Expand All @@ -725,6 +740,11 @@ YNN_ALWAYS_INLINE f64x2 floor_log2(f64x2 a) {
};
return f64x2{vld1q_f64(result)};
}
YNN_ALWAYS_INLINE f64x2 exp2_round(f64x2 a) {
float64x2_t magic = vdupq_n_f64(1023.0 + static_cast<double>(1ll << 52));
int64x2_t res_bits = vreinterpretq_s64_f64(vaddq_f64(a.v, magic));
return f64x2{vreinterpretq_f64_s64(vshlq_n_s64(res_bits, 52))};
}
#endif

namespace internal {
Expand All @@ -750,7 +770,7 @@ YNN_ALWAYS_INLINE float32x4_t not_f32(float32x4_t a) {

YNN_ALWAYS_INLINE f32x4 floor(f32x4 a) {
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
float32x4_t max_non_int_val = vdupq_n_f32(8388608.0f);
float32x4_t max_non_int_val = vdupq_n_f32(static_cast<float>(1 << 23));
uint32x4_t use_rounding = vcaltq_f32(a.v, max_non_int_val);
float32x4_t trunc = vcvtq_f32_s32(vcvtq_s32_f32(a.v));
uint32x4_t floor_mask = vcgtq_f32(trunc, a.v);
Expand All @@ -768,7 +788,7 @@ YNN_ALWAYS_INLINE f64x2 floor(f64x2 a) { return f64x2{vrndmq_f64(a.v)}; }

YNN_ALWAYS_INLINE f32x4 ceil(f32x4 a) {
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
float32x4_t max_non_int_val = vdupq_n_f32(8388608.0f);
float32x4_t max_non_int_val = vdupq_n_f32(static_cast<float>(1 << 23));
uint32x4_t use_rounding = vcaltq_f32(a.v, max_non_int_val);
float32x4_t trunc = vcvtq_f32_s32(vcvtq_s32_f32(a.v));
uint32x4_t ceil_mask = vcltq_f32(trunc, a.v);
Expand All @@ -786,7 +806,7 @@ YNN_ALWAYS_INLINE f64x2 ceil(f64x2 a) { return f64x2{vrndpq_f64(a.v)}; }

YNN_ALWAYS_INLINE f32x4 round(f32x4 a) {
#if defined(__ARM_ARCH) && __ARM_ARCH < 8
float32x4_t max_non_int_val = vdupq_n_f32(8388608.0f);
float32x4_t max_non_int_val = vdupq_n_f32(static_cast<float>(1 << 23));
float32x4_t filter = vreinterpretq_f32_u32(vcaltq_f32(a.v, max_non_int_val));
float32x4_t half = vdupq_n_f32(0.5f);
float32x4_t sign_mask = vdupq_n_f32(-0.0f);
Expand Down
1 change: 1 addition & 0 deletions ynnpack/base/simd/bench/arm64_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace ynn {
namespace simd {

BENCH_UNARY(neon, floor_log2, f64, 2);
BENCH_UNARY(neon, exp2_round, f64, 2);

} // namespace simd
} // namespace ynn
1 change: 1 addition & 0 deletions ynnpack/base/simd/bench/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ BENCH_PARTIAL_LOAD_STORE(neon, s16, 8);
BENCH_PARTIAL_LOAD_STORE(neon, s32, 4);

BENCH_UNARY(neon, floor_log2, f32, 4);
BENCH_UNARY(neon, exp2_round, f32, 4);

} // namespace simd
} // namespace ynn
8 changes: 4 additions & 4 deletions ynnpack/base/simd/bench/generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ static void BM_unary(benchmark::State& state, Fn fn) {
}
}

#define BENCH_UNARY(arch, op, type, N) \
void BM_##op##_##type##x##N##_##arch(benchmark::State& state) { \
BM_unary<type, N>(state, [](vec<type, N> a) { return floor_log2(a); }); \
} \
#define BENCH_UNARY(arch, op, type, N) \
void BM_##op##_##type##x##N##_##arch(benchmark::State& state) { \
BM_unary<type, N>(state, [](vec<type, N> a) { return op(a); }); \
} \
BENCHMARK(BM_##op##_##type##x##N##_##arch);

} // namespace simd
Expand Down
2 changes: 2 additions & 0 deletions ynnpack/base/simd/bench/x86_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace simd {

BENCH_UNARY(avx2, floor_log2, f32, 8);
BENCH_UNARY(avx2, floor_log2, f64, 4);
BENCH_UNARY(avx2, exp2_round, f32, 8);
BENCH_UNARY(avx2, exp2_round, f64, 4);

} // namespace simd
} // namespace ynn
2 changes: 2 additions & 0 deletions ynnpack/base/simd/bench/x86_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ BENCH_UNARY(avx512, floor_log2, f32, 4);
BENCH_UNARY(avx512, floor_log2, f64, 8);
BENCH_UNARY(avx512, floor_log2, f64, 4);
BENCH_UNARY(avx512, floor_log2, f64, 2);
BENCH_UNARY(avx512, exp2_round, f32, 16);
BENCH_UNARY(avx512, exp2_round, f64, 8);

} // namespace simd
} // namespace ynn
2 changes: 2 additions & 0 deletions ynnpack/base/simd/bench/x86_sse2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ BENCH_FMA(sse2, f32, 4);

BENCH_UNARY(sse2, floor_log2, f32, 4);
BENCH_UNARY(sse2, floor_log2, f64, 2);
BENCH_UNARY(sse2, exp2_round, f32, 4);
BENCH_UNARY(sse2, exp2_round, f64, 2);

} // namespace simd
} // namespace ynn
5 changes: 5 additions & 0 deletions ynnpack/base/simd/generic.inc
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> floor_log2(vec<T, N> a) {
return {floor_log2(a.lo()), floor_log2(a.hi())};
}

template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> exp2_round(vec<T, N> a) {
return {exp2_round(a.lo()), exp2_round(a.hi())};
}
template <typename T, size_t N>
YNN_ALWAYS_INLINE vec<T, N> ceil(vec<T, N> a) {
return {ceil(a.lo()), ceil(a.hi())};
Expand Down
1 change: 1 addition & 0 deletions ynnpack/base/simd/test/arm64_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ TEST_ROUND(arm_neon, f64, 2);
TEST_SQRT(arm_neon, f64, 2);
TEST_ABS(arm_neon, f64, 2);
TEST_FLOOR_LOG2(arm_neon, f64, 2);
TEST_EXP2_ROUND(arm_neon, f64, 2);

TEST_HORIZONTAL_MIN(arm_neon, f64, 2);
TEST_HORIZONTAL_MAX(arm_neon, f64, 2);
Expand Down
1 change: 1 addition & 0 deletions ynnpack/base/simd/test/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ TEST_ABS(arm_neon, s32, 4);
TEST_ABS(arm_neon, f32, 4);

TEST_FLOOR_LOG2(arm_neon, f32, 4);
TEST_EXP2_ROUND(arm_neon, f32, 4);

TEST_HORIZONTAL_MIN(arm_neon, u8, 16);
TEST_HORIZONTAL_MIN(arm_neon, s8, 16);
Expand Down
26 changes: 26 additions & 0 deletions ynnpack/base/simd/test/generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,32 @@ void test_floor_log2() {
#define TEST_FLOOR_LOG2(test_class, type, N) \
TEST_F(test_class, floor_log2_##type##x##N) { test_floor_log2<type, N>(); }

template <typename scalar, size_t N>
void test_exp2_round() {
using vector = vec<scalar, N>;

ReplicableRandomDevice rng;
for (auto _ : FuzzTest(std::chrono::milliseconds(100))) {
scalar a[vector::N];
auto abs_max = std::log2(type_info<scalar>::max()) - 2;
fill_random(a, vector::N, rng, -abs_max, abs_max);

scalar result[vector::N];
store(result, exp2_round(load(a, vector::N)));

for (size_t i = 0; i < vector::N; ++i) {
// Allow any integer k such that |k - a| <= 0.5.
// This accommodates different tie-breaking behaviors.
auto k = std::log2(result[i]);
ASSERT_NEAR(k, a[i], 0.5001) << a[i];
ASSERT_EQ(k, std::nearbyint(k)) << a[i];
}
}
}

#define TEST_EXP2_ROUND(test_class, type, N) \
TEST_F(test_class, exp2_round_##type##x##N) { test_exp2_round<type, N>(); }

struct min_op {
template <typename T>
T operator()(T a, T b) {
Expand Down
4 changes: 4 additions & 0 deletions ynnpack/base/simd/test/wasm_simd128.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ TEST_FLOOR(wasm_simd128, f32, 4);
TEST_CEIL(wasm_simd128, f32, 4);
TEST_ROUND(wasm_simd128, f32, 4);
TEST_SQRT(wasm_simd128, f32, 4);
TEST_FLOOR_LOG2(wasm_simd128, f32, 4);
TEST_FLOOR_LOG2(wasm_simd128, f64, 2);
TEST_EXP2_ROUND(wasm_simd128, f32, 4);
TEST_EXP2_ROUND(wasm_simd128, f64, 2);

TEST_ABS(wasm_simd128, s8, 16);
TEST_ABS(wasm_simd128, s16, 8);
Expand Down
2 changes: 2 additions & 0 deletions ynnpack/base/simd/test/x86_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ TEST_ABS(x86_avx2, s32, 8);

TEST_FLOOR_LOG2(x86_avx2, f32, 8);
TEST_FLOOR_LOG2(x86_avx2, f64, 4);
TEST_EXP2_ROUND(x86_avx2, f32, 8);
TEST_EXP2_ROUND(x86_avx2, f64, 4);

TEST_HORIZONTAL_MIN(x86_avx2, u8, 32);
TEST_HORIZONTAL_MIN(x86_avx2, s8, 32);
Expand Down
2 changes: 2 additions & 0 deletions ynnpack/base/simd/test/x86_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ TEST_FLOOR_LOG2(x86_avx512, f32, 4);
TEST_FLOOR_LOG2(x86_avx512, f64, 8);
TEST_FLOOR_LOG2(x86_avx512, f64, 4);
TEST_FLOOR_LOG2(x86_avx512, f64, 2);
TEST_EXP2_ROUND(x86_avx512, f32, 16);
TEST_EXP2_ROUND(x86_avx512, f64, 8);

TEST_FLOOR(x86_avx512, f32, 16);
TEST_FLOOR(x86_avx512, f64, 8);
Expand Down
2 changes: 2 additions & 0 deletions ynnpack/base/simd/test/x86_sse2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ TEST_ABS(x86_sse2, f32, 4);

TEST_FLOOR_LOG2(x86_sse2, f32, 4);
TEST_FLOOR_LOG2(x86_sse2, f64, 2);
TEST_EXP2_ROUND(x86_sse2, f32, 4);
TEST_EXP2_ROUND(x86_sse2, f64, 2);

TEST_HORIZONTAL_MIN(x86_sse2, u8, 16);
TEST_HORIZONTAL_MIN(x86_sse2, s16, 8);
Expand Down
9 changes: 8 additions & 1 deletion ynnpack/base/simd/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ template <typename T, size_t N>
vec<T, N> floor(vec<T, N> a);
template <typename T, size_t N>
vec<T, N> floor_log2(vec<T, N> a);
// This function is permitted to use any rounding mode.
template <typename T, size_t N>
vec<T, N> exp2_round(vec<T, N> a);
template <typename T, size_t N>
vec<T, N> ceil(vec<T, N> a);
template <typename T, size_t N>
Expand Down Expand Up @@ -317,7 +320,11 @@ YNN_ALWAYS_INLINE vec<T, 1> sub_sat(vec<T, 1> a, vec<T, 1> b) {
}
template <typename T>
YNN_ALWAYS_INLINE vec<T, 1> floor_log2(vec<T, 1> a) {
return vec<T, 1>{floor_log2(a.v)};
return vec<T, 1>{ynn::floor_log2(a.v)};
}
template <typename T>
YNN_ALWAYS_INLINE vec<T, 1> exp2_round(vec<T, 1> a) {
return vec<T, 1>{ynn::exp2_round(a.v)};
}

template <typename To, typename From>
Expand Down
38 changes: 38 additions & 0 deletions ynnpack/base/simd/wasm_simd128.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,44 @@ YNN_ALWAYS_INLINE f32x4 round(f32x4 a) {
}
YNN_ALWAYS_INLINE f32x4 sqrt(f32x4 a) { return f32x4{wasm_f32x4_sqrt(a.v)}; }

YNN_ALWAYS_INLINE f32x4 floor_log2(f32x4 a) {
const v128_t sign_mask = wasm_f32x4_splat(-0.0f);
const v128_t is_zero = wasm_f32x4_eq(a.v, wasm_f32x4_splat(0.0f));
a.v = wasm_v128_or(wasm_v128_and(is_zero, sign_mask), a.v);

const v128_t sign_and_exp_mask = wasm_i32x4_splat(0xFF800000);
v128_t exp = wasm_v128_and(a.v, sign_and_exp_mask);

const v128_t infinity =
wasm_f32x4_splat(std::numeric_limits<float>::infinity());
const v128_t is_inf = wasm_f32x4_eq(a.v, infinity);

exp = wasm_i32x4_shr(exp, 8);

const v128_t bias_256 = wasm_f32x4_splat(256.0f);
const v128_t bias_383 = wasm_f32x4_splat(383.0f);
const v128_t res = wasm_f32x4_sub(wasm_v128_or(bias_256, exp), bias_383);
return f32x4{wasm_v128_bitselect(infinity, res, is_inf)};
}

YNN_ALWAYS_INLINE f64x2 floor_log2(f64x2 a) {
return f64x2{wasm_f64x2_make(
ynn::floor_log2(wasm_f64x2_extract_lane(a.v, 0)),
ynn::floor_log2(wasm_f64x2_extract_lane(a.v, 1)))};
}

YNN_ALWAYS_INLINE f32x4 exp2_round(f32x4 a) {
const v128_t magic = wasm_f32x4_splat(127.0f + static_cast<float>(1 << 23));
const v128_t res_bits = wasm_f32x4_add(a.v, magic);
return f32x4{wasm_i32x4_shl(res_bits, 23)};
}

YNN_ALWAYS_INLINE f64x2 exp2_round(f64x2 a) {
return f64x2{wasm_f64x2_make(
ynn::exp2_round(wasm_f64x2_extract_lane(a.v, 0)),
ynn::exp2_round(wasm_f64x2_extract_lane(a.v, 1)))};
}

YNN_ALWAYS_INLINE s16x16 cast(s8x16 a, int16_t) {
return {s16x8{wasm_i16x8_extend_low_i8x16(a.v)},
s16x8{wasm_i16x8_extend_high_i8x16(a.v)}};
Expand Down
13 changes: 13 additions & 0 deletions ynnpack/base/simd/x86_avx2_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,19 @@ YNN_ALWAYS_INLINE f32x8 cast(bf16x8 a, float) {
16))};
}

YNN_ALWAYS_INLINE f32x8 exp2_round(f32x8 a) {
const __m256 magic = _mm256_set1_ps(127.0f + static_cast<float>(1 << 23));
const __m256 res_bits = _mm256_add_ps(a.v, magic);
return f32x8{_mm256_castsi256_ps(
_mm256_slli_epi32(_mm256_castps_si256(res_bits), 23))};
}
YNN_ALWAYS_INLINE f64x4 exp2_round(f64x4 a) {
const __m256d magic = _mm256_set1_pd(1023.0 + static_cast<double>(1ll << 52));
const __m256d res_bits = _mm256_add_pd(a.v, magic);
return f64x4{_mm256_castsi256_pd(
_mm256_slli_epi64(_mm256_castpd_si256(res_bits), 52))};
}

} // namespace simd

} // namespace ynn
Expand Down
13 changes: 13 additions & 0 deletions ynnpack/base/simd/x86_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -956,6 +956,19 @@ YNN_ALWAYS_INLINE f64x8 floor_log2(f64x8 a) {
negative, res, _mm512_set1_pd(std::numeric_limits<double>::quiet_NaN()))};
}

YNN_ALWAYS_INLINE f32x16 exp2_round(f32x16 a) {
const __m512 magic = _mm512_set1_ps(127.0f + static_cast<float>(1 << 23));
const __m512 res_bits = _mm512_add_ps(a.v, magic);
return f32x16{_mm512_castsi512_ps(
_mm512_slli_epi32(_mm512_castps_si512(res_bits), 23))};
}
YNN_ALWAYS_INLINE f64x8 exp2_round(f64x8 a) {
const __m512d magic = _mm512_set1_pd(1023.0 + static_cast<double>(1ll << 52));
const __m512d res_bits = _mm512_add_pd(a.v, magic);
return f64x8{_mm512_castsi512_pd(
_mm512_slli_epi64(_mm512_castpd_si512(res_bits), 52))};
}

template <int Index>
YNN_ALWAYS_INLINE s32x4 extract(s32x16 x, decltype(s32x4::N)) {
return s32x4{_mm512_extracti32x4_epi32(x.v, Index)};
Expand Down
13 changes: 13 additions & 0 deletions ynnpack/base/simd/x86_sse2_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,19 @@ YNN_ALWAYS_INLINE f64x2 abs(f64x2 a) {
return f64x2{_mm_andnot_pd(_mm_set1_pd(-0.0), a.v)};
}

YNN_ALWAYS_INLINE f32x4 exp2_round(f32x4 a) {
const __m128 magic = _mm_set1_ps(127.0 + static_cast<float>(1 << 23));
const __m128 res_bits = _mm_add_ps(a.v, magic);
return f32x4{
_mm_castsi128_ps(_mm_slli_epi32(_mm_castps_si128(res_bits), 23))};
}
YNN_ALWAYS_INLINE f64x2 exp2_round(f64x2 a) {
const __m128d magic = _mm_set1_pd(1023.0 + static_cast<double>(1ll << 52));
const __m128d res_bits = _mm_add_pd(a.v, magic);
return f64x2{
_mm_castsi128_pd(_mm_slli_epi64(_mm_castpd_si128(res_bits), 52))};
}

YNN_ALWAYS_INLINE double horizontal_max(f64x2 a) {
return _mm_cvtsd_f64(_mm_max_sd(a.v, _mm_shuffle_pd(a.v, a.v, 1)));
}
Expand Down
6 changes: 6 additions & 0 deletions ynnpack/kernels/elementwise/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,11 @@ def floor_log2(value):
return Op(value.ty, "floor_log2", [value])


@intrinsic
def exp2_round(value):
return Op(value.ty, "exp2_round", [value])


@intrinsic
def sqrt(value):
return Op(value.ty, "sqrt", [value])
Expand Down Expand Up @@ -976,6 +981,7 @@ def __init__(self):
"round",
"floor",
"floor_log2",
"exp2_round",
"ceil",
"sqrt",
"reinterpret_cast",
Expand Down
Loading
Loading