diff --git a/bench/q8gemm.cc b/bench/q8gemm.cc
index f06d9e6..8800efd 100644
--- a/bench/q8gemm.cc
+++ b/bench/q8gemm.cc
@@ -299,6 +299,105 @@ class Q8GEMM_XZP : public Q8GEMM {
   qnnp_q31_requantization_params requantizationParams_;
 };
 
+class Q8GEMM_PER_CHANNEL : public Q8GEMM {
+ public:
+  inline Q8GEMM_PER_CHANNEL(uint32_t mr, uint32_t nr, uint32_t np, uint32_t kr) : Q8GEMM(mr, nr, np, kr) {}
+  virtual void SetUp(const benchmark::State&) override
+  {
+    std::random_device randomDevice;
+    auto rng = std::mt19937(randomDevice());
+    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
+    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
+
+    a_.resize(mc() * kc());
+    std::generate(a_.begin(), a_.end(), std::ref(u8rng));
+    k_.resize(nc() * kc());
+    std::generate(k_.begin(), k_.end(), std::ref(u8rng));
+    b_.resize(mc());
+    std::generate(b_.begin(), b_.end(), std::ref(s32rng));
+    w_.resize(kcStride() * ncStride() + ncStride() * sizeof(int32_t) / sizeof(uint8_t));
+    const uint8_t kernel_zero_point_center = 127;
+    kernelZeroPointPerChannel_.resize(nr());
+    requantizationScalePerChannel_.resize(nr());
+    multiplierPerChannel_.resize(nr());
+    rightShiftPerChannel_.resize(nr());
+    const float scale_min = 0.5f;
+    const float scale_max = 0.99999f;
+    for (size_t i = 0; i < nr(); ++i) {
+      kernelZeroPointPerChannel_[i] =
+        static_cast<uint8_t>(std::min(255, std::max(0, kernel_zero_point_center + (int)(i - nr()/2))));
+      requantizationScalePerChannel_[i] = scale_min + i * (scale_max - scale_min) / nr();
+    }
+    std::fill(w_.begin(), w_.end(), kernel_zero_point_center);
+    pack_q8gemm_w_per_channel(
+      nc(), kc(),
+      nr(), np(), kr(),
+      127, kernelZeroPointPerChannel_.data(),
+      k(), b(), w());
+    c_.resize(mc() * nc());
+    std::fill(c_.begin(), c_.end(), 0xA5);
+    quantizationParams_ =
+      qnnp_compute_conv_quantization_params_per_channel(
+        127, nr(), kernelZeroPointPerChannel_.data(),
+        requantizationScalePerChannel_.data(), multiplierPerChannel_.data(), rightShiftPerChannel_.data(),
+        127, 1, 254);
+  }
+
+  virtual void TearDown(benchmark::State& state) override
+  {
+    state.SetItemsProcessed(uint64_t(state.iterations()) * 2 * mc() * nc() * kc());
+    a_.clear();
+    k_.clear();
+    b_.clear();
+    w_.clear();
+    c_.clear();
+    kernelZeroPointPerChannel_.clear();
+    kernelAndInputScalePerChannel_.clear();
+    requantizationScalePerChannel_.clear();
+    multiplierPerChannel_.clear();
+    rightShiftPerChannel_.clear();
+  }
+
+ protected:
+  std::vector<uint8_t> kernelZeroPointPerChannel_;
+  std::vector<float> kernelAndInputScalePerChannel_;
+  std::vector<float> requantizationScalePerChannel_;
+  std::vector<int32_t> multiplierPerChannel_;
+  std::vector<int32_t> rightShiftPerChannel_;
+};
+
+template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
+class Q8GEMM_PER_CHANNEL_L1 : public Q8GEMM_PER_CHANNEL {
+ public:
+  inline Q8GEMM_PER_CHANNEL_L1() : Q8GEMM_PER_CHANNEL(MR, NR, NP, KR)
+  {
+    cpuinfo_initialize();
+    const size_t l1d_size = cpuinfo_get_l1d_cache(0)->size;
+    const size_t l1d_reserve = 512;
+    kc_ = ((l1d_size - l1d_reserve) / sizeof(uint8_t) - mr() * nr()) / (mr() + nr());
+    if (kr() != 1) {
+      kc_ = kc_ / kr() * kr();
+    } else {
+      kc_ = kc_ / nr() * nr();
+    }
+  }
+};
+
+template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
+class Q8GEMM_PER_CHANNEL_Op : public Q8GEMM_PER_CHANNEL {
+ public:
+  inline Q8GEMM_PER_CHANNEL_Op() : Q8GEMM_PER_CHANNEL(MR, NR, NP, KR) {}
+
+  virtual void SetUp(const benchmark::State& state) override
+  {
+    mc_ = state.range(0);
+    nc_ = state.range(1);
+    kc_ = state.range(2);
+
+    Q8GEMM_PER_CHANNEL::SetUp(state);
+  }
+};
+
 template <uint32_t MR, uint32_t NR, uint32_t NP, uint32_t KR>
 class Q8GEMM_XZP_L1 : public Q8GEMM_XZP {
  public:
@@ -647,6 +746,40 @@
BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(MobileNetV1GemmArgumen BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(SqueezeNetV10GemmArguments); BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__aarch32_neon)->Apply(GemmArguments); +BENCHMARK_TEMPLATE_F(Q8GEMM_PER_CHANNEL_L1, 4x8__aarch32_neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ + for (auto _ : state) { + q8gemm_ukernel_4x8__aarch32_neon_per_channel( + mr(), nr(), kc(), + a(), kc() * sizeof(uint8_t), + w(), + c(), mr() * sizeof(uint8_t), + quantizationParams(), 0); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + q8gemm_ukernel_4x8__aarch32_neon_per_channel( + mrr, nrr, kc(), + a() + m * kc(), kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, nc() * sizeof(uint8_t), + quantizationParams(), 0); + } + } + } +} +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__aarch32_neon_per_channel)->Apply(GemmArguments); + BENCHMARK_TEMPLATE_F(Q8GEMM_XZP_L1, 4x8c2__aarch32_neon, 4, 8, 8, 2)(benchmark::State& state) { for (auto _ : state) { @@ -770,6 +903,41 @@ BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(MobileNetV1GemmArguments); BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(SqueezeNetV10GemmArguments); BENCHMARK_REGISTER_F(Q8GEMM_Op, 4x8__neon)->Apply(GemmArguments); +BENCHMARK_TEMPLATE_F(Q8GEMM_PER_CHANNEL_L1, 4x8__neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ + for (auto _ : state) { + q8gemm_ukernel_4x8__neon_per_channel( + mr(), nr(), kc(), + a(), kc() * sizeof(uint8_t), + w(), + c(), mr() * sizeof(uint8_t), + quantizationParams(), 0); + } +} + +BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel, 4, 8, 8, 1)(benchmark::State& state) +{ + for (auto _ : state) { + for (uint32_t m = 0; m < mc(); m += mr()) { + const uint32_t mrr = min(mc() - m, mr()); + for (uint32_t n = 0; n < nc(); n += nr()) { + const uint32_t nrr = min(nc() - n, nr()); + q8gemm_ukernel_4x8__neon_per_channel( + mrr, nrr, kc(), + a() + m * kc(), kc() * sizeof(uint8_t), + w() + n * (kcStride() * sizeof(uint8_t) + sizeof(int32_t)), + c() + m * nc() + n, nc() * sizeof(uint8_t), + quantizationParams(), 0); + } + } + } +} + +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(ShuffleNetV1G1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(MobileNetV1GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(SqueezeNetV10GemmArguments); +BENCHMARK_REGISTER_F(Q8GEMM_PER_CHANNEL_Op, 4x8__neon_per_channel)->Apply(GemmArguments); + BENCHMARK_TEMPLATE_DEFINE_F(Q8GEMM_Op, 8x8__neon, 8, 8, 8, 1)(benchmark::State& state) { for (auto _ : state) { diff --git a/configure.py b/configure.py index 8c9d1ae..b6adc24 100755 --- a/configure.py +++ b/configure.py @@ -107,6 +107,7 @@ def main(args): build.cc("q8gavgpool/up8xm-neon.c"), build.cc("q8gemm/4x-sumrows-neon.c"), 
build.cc("q8gemm/4x8-neon.c"), + build.cc("q8gemm/4x8-neon_per_channel.c"), build.cc("q8gemm/4x8c2-xzp-neon.c"), build.cc("q8gemm/6x4-neon.c"), build.cc("q8gemm/8x8-neon.c"), @@ -128,6 +129,7 @@ def main(args): build.cc("q8conv/4x8-aarch32-neon.S"), build.cc("q8dwconv/up8x9-aarch32-neon.S"), build.cc("q8gemm/4x8-aarch32-neon.S"), + build.cc("q8gemm/4x8-aarch32-neon-per-channel.S"), build.cc("q8gemm/4x8c2-xzp-aarch32-neon.S"), ] if build.target.is_arm64: diff --git a/src/q8gemm/4x8-aarch32-neon-per-channel.S b/src/q8gemm/4x8-aarch32-neon-per-channel.S new file mode 100644 index 0000000..a88f368 --- /dev/null +++ b/src/q8gemm/4x8-aarch32-neon-per-channel.S @@ -0,0 +1,819 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +.syntax unified + +# void q8gemm_ukernel_4x8__aarch32_neon_per_channel( +# size_t mr, +# size_t nr, +# size_t k, +# const uint8_t*restrict a, +# size_t a_stride, +# const void*restrict w, +# uint8_t*restrict c, +# size_t c_stride, +# const union qnnp_conv_quantization_params quantization_params[restrict static 1], +# size_t kernel_quantization_params_offset) +BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel + .arm +#ifndef __APPLE__ + .arch armv7-a + .fpu neon +#endif + # Load w + # - ip = w + LDR ip, [sp, 4] + PUSH {r4-r8} + + VPUSH {d8-d15} + # Load quantization params + # - r7 = quantization_params + LDR r7, [sp, 100] + + # load kernel_quantization_params_offset + LDR r5, [sp, 104] + + # Load bias0123, bias4567 + VLDM ip!, {d16-d19} + + # Load a_stride + # - r6 = a_stride + LDR r6, [sp, 84] + CMP r0, 2 + + # a1 = a0 + a_stride + ADD r4, r3, r6 + + # Load b_zero_point from kernel_zero_point_v: + # - d15 = b_zero_point + MOV r8, 16 + ADD r8, r7, r8 + LDR r8, [r8, r5] + VLD1.8 {d15}, [r8] + MOVLO r4, r3 + + # Move kernel_quantization_params_offset to r8 to use later + MOV r8, r5 + + ADD r7, r7, 4 + ADD r5, r4, r6 + + # q10 := vacc1x0123 + VMOV.I32 q10, q8 + MOVLS r5, r4 + # q11 := vacc1x4567 + VMOV.I32 q11, q9 + ADD r6, r5, r6 + # q12 := vacc2x0123 + VMOV.I32 q12, q8 + CMP r0, 4 + # q13 := vacc2x4567 + VMOV.I32 q13, q9 + MOVNE r6, r5 + # q14 := vacc3x0123 + VMOV.I32 q14, q8 + SUBS r2, r2, 8 + # q15 := vacc3x4567 + VMOV.I32 q15, q9 + + BLO 1f + + .p2align 5 +0: + # Load a0 + # - d1 = a0 + VLD1.8 {d1}, [r3]! + + # Load a1 + # - d3 = a1 + VLD1.8 {d3}, [r4]! + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r5]! + + # q0 = va0 = a0 + VMOVL.U8 q0, d1 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r6]! + + # q1 = va1 = a1 + VMOVL.U8 q1, d3 + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + VMOVL.U8 q2, d5 + # q3 = va3 = a3 + VMOVL.U8 q3, d7 + + ### Channel 0 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! 
+ + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + ### Channel 1 ### + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + + # Load b0-b7 (channel 3) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 3) + # - d11 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + + # Load b0-b7 (channel 4) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 4) + # - d11 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + + # Load b0-b7 (channel 5) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 5) + # - d9 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! 
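+
+  # Channels 4-7 take the activation scalar from d1/d3/d5/d7 (the high halves
+  # of the widened va0-va3), so the lane index restarts at 0 here: d1[0] is a0[4].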
+ + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + + # Load b0-b7 (channel 7) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 7) + # - d11 = vb4567 (channel 7) + VSUBL.U8 q5, d11, d15 + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + ### Channel 8 ### + SUBS r2, r2, 8 + + # vacc0x0123 += vb0123 * va0[7] + VMLAL.S16 q8, d10, d1[3] + # vacc0x4567 += vb4567 * va0[7] + VMLAL.S16 q9, d11, d1[3] + + # vacc1x0123 += vb0123 * va1[7] + VMLAL.S16 q10, d10, d3[3] + # vacc1x4567 += vb4567 * va1[7] + VMLAL.S16 q11, d11, d3[3] + + # vacc2x0123 += vb0123 * va2[7] + VMLAL.S16 q12, d10, d5[3] + # vacc2x4567 += vb4567 * va2[7] + VMLAL.S16 q13, d11, d5[3] + + # vacc3x0123 += vb0123 * va3[7] + VMLAL.S16 q14, d10, d7[3] + # vacc3x4567 += vb4567 * va3[7] + VMLAL.S16 q15, d11, d7[3] + + BHS 0b + +1: + CMP r2, -8 + BEQ 2f + + # Adjust a0, a1, a2, a3 + ADD r3, r2 + ADD r4, r2 + ADD r5, r2 + ADD r6, r2 + + # a_shift = 8 * k - 64 + LSL r2, r2, 3 + VDUP.32 d13, r2 + + # Load a0 + # - d1 = a0 + VLD1.8 {d1}, [r3] + + # Load a1 + # - d3 = a1 + VLD1.8 {d3}, [r4] + + # Load b0-b7 (channel 0) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # Load a2 + # - d5 = a2 + VLD1.8 {d5}, [r5] + + # q0 = va0 = a0 + VSHL.U64 d1, d1, d13 + VMOVL.U8 q0, d1 + + # Load a3 + # - d7 = a3 + VLD1.8 {d7}, [r6] + + # q1 = va1 = a1 + VSHL.U64 d3, d3, d13 + VMOVL.U8 q1, d3 + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 0) + # - d9 = vb4567 (channel 0) + VSUBL.U8 q4, d9, d15 + + # q2 = va2 = a2 + VSHL.U64 d5, d5, d13 + VMOVL.U8 q2, d5 + # q3 = va3 = a3 + VSHL.U64 d7, d7, d13 + VMOVL.U8 q3, d7 + + ### Channel 0 ### + + # vacc0x0123 += vb0123 * va0[0] + VMLAL.S16 q8, d8, d0[0] + # vacc0x4567 += vb4567 * va0[0] + VMLAL.S16 q9, d9, d0[0] + + # vacc1x0123 += vb0123 * va1[0] + VMLAL.S16 q10, d8, d2[0] + # vacc1x4567 += vb4567 * va1[0] + VMLAL.S16 q11, d9, d2[0] + + # vacc2x0123 += vb0123 * va2[0] + VMLAL.S16 q12, d8, d4[0] + # vacc2x4567 += vb4567 * va2[0] + VMLAL.S16 q13, d9, d4[0] + + # vacc3x0123 += vb0123 * va3[0] + VMLAL.S16 q14, d8, d6[0] + # vacc3x4567 += vb4567 * va3[0] + VMLAL.S16 q15, d9, d6[0] + + CMP r2, -48 + BLO 2f + + ### Channel 1 ### + + # Load b0-b7 (channel 1) + # - d11 = b0-b7 + VLD1.8 {d11}, [ip:64]! 
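+
+  # Remainder path: r2 holds 8 * k_remainder - 64 at this point, so the
+  # CMP r2, -48 / -32 / -16 checks (paired with BLO and BLS) decide how many
+  # of channels 1-6 still need to be processed.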
+ + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 1) + # - d11 = vb4567 (channel 1) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[1] + VMLAL.S16 q8, d10, d0[1] + # vacc0x4567 += vb4567 * va0[1] + VMLAL.S16 q9, d11, d0[1] + + # vacc1x0123 += vb0123 * va1[1] + VMLAL.S16 q10, d10, d2[1] + # vacc1x4567 += vb4567 * va1[1] + VMLAL.S16 q11, d11, d2[1] + + # vacc2x0123 += vb0123 * va2[1] + VMLAL.S16 q12, d10, d4[1] + # vacc2x4567 += vb4567 * va2[1] + VMLAL.S16 q13, d11, d4[1] + + # vacc3x0123 += vb0123 * va3[1] + VMLAL.S16 q14, d10, d6[1] + # vacc3x4567 += vb4567 * va3[1] + VMLAL.S16 q15, d11, d6[1] + + ### Channel 2 ### + BLS 2f + + # Load b0-b7 (channel 2) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 2) + # - d9 = vb4567 (channel 2) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[2] + VMLAL.S16 q8, d8, d0[2] + # vacc0x4567 += vb4567 * va0[2] + VMLAL.S16 q9, d9, d0[2] + + # vacc1x0123 += vb0123 * va1[2] + VMLAL.S16 q10, d8, d2[2] + # vacc1x4567 += vb4567 * va1[2] + VMLAL.S16 q11, d9, d2[2] + + # vacc2x0123 += vb0123 * va2[2] + VMLAL.S16 q12, d8, d4[2] + # vacc2x4567 += vb4567 * va2[2] + VMLAL.S16 q13, d9, d4[2] + + # vacc3x0123 += vb0123 * va3[2] + VMLAL.S16 q14, d8, d6[2] + # vacc3x4567 += vb4567 * va3[2] + VMLAL.S16 q15, d9, d6[2] + + ### Channel 3 ### + CMP r2, -32 + BLO 2f + + # Load b0-b7 (channel 3) + # - d9 = b0-b7 + VLD1.8 {d11}, [ip:64]! + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 3) + # - d9 = vb4567 (channel 3) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[3] + VMLAL.S16 q8, d10, d0[3] + # vacc0x4567 += vb4567 * va0[3] + VMLAL.S16 q9, d11, d0[3] + + # vacc1x0123 += vb0123 * va1[3] + VMLAL.S16 q10, d10, d2[3] + # vacc1x4567 += vb4567 * va1[3] + VMLAL.S16 q11, d11, d2[3] + + # vacc2x0123 += vb0123 * va2[3] + VMLAL.S16 q12, d10, d4[3] + # vacc2x4567 += vb4567 * va2[3] + VMLAL.S16 q13, d11, d4[3] + + # vacc3x0123 += vb0123 * va3[3] + VMLAL.S16 q14, d10, d6[3] + # vacc3x4567 += vb4567 * va3[3] + VMLAL.S16 q15, d11, d6[3] + + ### Channel 4 ### + BLS 2f + + # Load b0-b7 (channel 4) + # - d11 = b0-b7 + VLD1.8 {d9}, [ip:64]! + + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 4) + # - d11 = vb4567 (channel 4) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[4] + VMLAL.S16 q8, d8, d1[0] + # vacc0x4567 += vb4567 * va0[4] + VMLAL.S16 q9, d9, d1[0] + + # vacc1x0123 += vb0123 * va1[4] + VMLAL.S16 q10, d8, d3[0] + # vacc1x4567 += vb4567 * va1[4] + VMLAL.S16 q11, d9, d3[0] + + # vacc2x0123 += vb0123 * va2[4] + VMLAL.S16 q12, d8, d5[0] + # vacc2x4567 += vb4567 * va2[4] + VMLAL.S16 q13, d9, d5[0] + + # vacc3x0123 += vb0123 * va3[4] + VMLAL.S16 q14, d8, d7[0] + # vacc3x4567 += vb4567 * va3[4] + VMLAL.S16 q15, d9, d7[0] + + ### Channel 5 ### + CMP r2, -16 + BLO 2f + + # Load b0-b7 (channel 5) + # - d13 = b0-b7 + VLD1.8 {d11}, [ip:64]! 
+ + # q5 = b0:7 - b_zero_point + # - d10 = vb0123 (channel 5) + # - d11 = vb4567 (channel 5) + VSUBL.U8 q5, d11, d15 + + # vacc0x0123 += vb0123 * va0[5] + VMLAL.S16 q8, d10, d1[1] + # vacc0x4567 += vb4567 * va0[5] + VMLAL.S16 q9, d11, d1[1] + + # vacc1x0123 += vb0123 * va1[5] + VMLAL.S16 q10, d10, d3[1] + # vacc1x4567 += vb4567 * va1[5] + VMLAL.S16 q11, d11, d3[1] + + # vacc2x0123 += vb0123 * va2[5] + VMLAL.S16 q12, d10, d5[1] + # vacc2x4567 += vb4567 * va2[5] + VMLAL.S16 q13, d11, d5[1] + + # vacc3x0123 += vb0123 * va3[5] + VMLAL.S16 q14, d10, d7[1] + # vacc3x4567 += vb4567 * va3[5] + VMLAL.S16 q15, d11, d7[1] + + ### Channel 6 ### + BLS 2f + + # Load b0-b7 (channel 6) + # - d9 = b0-b7 + VLD1.8 {d9}, [ip:64] + + # q4 = b0:7 - b_zero_point + # - d8 = vb0123 (channel 6) + # - d9 = vb4567 (channel 6) + VSUBL.U8 q4, d9, d15 + + # vacc0x0123 += vb0123 * va0[6] + VMLAL.S16 q8, d8, d1[2] + # vacc0x4567 += vb4567 * va0[6] + VMLAL.S16 q9, d9, d1[2] + + # vacc1x0123 += vb0123 * va1[6] + VMLAL.S16 q10, d8, d3[2] + # vacc1x4567 += vb4567 * va1[6] + VMLAL.S16 q11, d9, d3[2] + + # vacc2x0123 += vb0123 * va2[6] + VMLAL.S16 q12, d8, d5[2] + + # vacc2x4567 += vb4567 * va2[6] + VMLAL.S16 q13, d9, d5[2] + + # vacc3x0123 += vb0123 * va3[6] + VMLAL.S16 q14, d8, d7[2] + # vacc3x4567 += vb4567 * va3[6] + VMLAL.S16 q15, d9, d7[2] + + .p2align 4 +2: + # Load multiplier: + # - d12 = vmultiplier ( vmultiplier0x0123 ) + ADD r7, 16 + LDR r8, [sp, 104] + LDR r8, [r7, r8] + VLD1.32 {d12, d13}, [r8]! + SUB r7, 12 + + VQRDMULH.S32 q8, q8, q6 + VQRDMULH.S32 q10, q10, q6 + VQRDMULH.S32 q12, q12, q6 + VQRDMULH.S32 q14, q14, q6 + + VLD1.32 {d12, d13}, [r8] + + VQRDMULH.S32 q9, q9, q6 + VQRDMULH.S32 q11, q11, q6 + VQRDMULH.S32 q13, q13, q6 + VQRDMULH.S32 q15, q15, q6 + + # Load right_shift + # - q4 = d8:d9 = vright_shift_0x0123 + ADD r7, 16 + LDR r8, [sp, 104] + LDR r8, [r7, r8] + VLD1.32 {d8, d9}, [r8]! + SUB r7, 12 + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask_0x0123 + VCEQ.S32 q5, q4, 0 + + VBIC q0, q8, q5 + VBIC q1, q10, q5 + VBIC q2, q12, q5 + VBIC q3, q14, q5 + + VSRA.S32 q8, q0, 31 + VSRA.S32 q10, q1, 31 + VSRA.S32 q12, q2, 31 + VSRA.S32 q14, q3, 31 + + VRSHL.S32 q8, q8, q4 + VRSHL.S32 q10, q10, q4 + VRSHL.S32 q12, q12, q4 + VRSHL.S32 q14, q14, q4 + + # - q4 = d8:d9 = vright_shift_0x4567 + VLD1.32 {d8, d9}, [r8] + + # Compute vzero_shift_mask + # - q5 = vzero_shift_mask_0x4567 + VCEQ.S32 q5, q4, 0 + + VBIC q0, q9, q5 + VBIC q1, q11, q5 + VBIC q2, q13, q5 + VBIC q3, q15, q5 + + VSRA.S32 q9, q0, 31 + VSRA.S32 q11, q1, 31 + VSRA.S32 q13, q2, 31 + VSRA.S32 q15, q3, 31 + + VRSHL.S32 q9, q9, q4 + VRSHL.S32 q11, q11, q4 + VRSHL.S32 q13, q13, q4 + VRSHL.S32 q15, q15, q4 + + # Load output_zero_point + # - q7 = d14:d15 = voutput_zero_point + VLD1.16 {d14[], d15[]}, [r7]! + + # Load max: + # - q5 = d10:d11 = vmax + VLD1.8 {d10[], d11[]}, [r7]! + + # Load c, c_stride: + # - r2 = c + # - r2 = c_stride + LDRD r2, r3, [sp, 92] + + VQMOVN.S32 d16, q8 + VQMOVN.S32 d17, q9 + VQMOVN.S32 d18, q10 + VQMOVN.S32 d19, q11 + VQMOVN.S32 d20, q12 + VQMOVN.S32 d21, q13 + VQMOVN.S32 d22, q14 + VQMOVN.S32 d23, q15 + + # Load min: + # - q4 = q8:q9 = vmin + VLD1.8 {d8[], d9[]}, [r7]! 
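+
+  # Per-channel requantization: the multiplier and right-shift applied above
+  # were loaded as full 8-lane vectors (lanes 0-3 and 4-7) through the
+  # multiplier_v / right_shift_v tables selected via
+  # kernel_quantization_params_offset, rather than the scalar broadcasts used
+  # by the per-tensor q8gemm_ukernel_4x8__aarch32_neon kernel.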
+ ADD r4, r2, r3 + + VQADD.S16 q8, q8, q7 + VQADD.S16 q9, q9, q7 + CMP r0, 2 + VQADD.S16 q10, q10, q7 + VQADD.S16 q11, q11, q7 + MOVLO r4, r2 + + VQMOVUN.S16 d16, q8 + VQMOVUN.S16 d17, q9 + ADD r5, r4, r3 + VQMOVUN.S16 d18, q10 + VQMOVUN.S16 d19, q11 + MOVLS r5, r4 + + VMIN.U8 q8, q8, q5 + CMP r0, 4 + VMIN.U8 q9, q9, q5 + ADD r3, r5, r3 + + VMAX.U8 q8, q8, q4 + MOVNE r3, r5 + CMP r1, 8 + VMAX.U8 q9, q9, q4 + + BNE 4f + + VST1.8 {d16}, [r2] + VST1.8 {d17}, [r4] + VST1.8 {d18}, [r5] + VST1.8 {d19}, [r3] + + VPOP {d8-d15} + POP {r4-r8} + BX lr + + .p2align 3 +4: + CMP r1, 4 + BLO 5f + + VST1.32 {d16[0]}, [r2]! + VST1.32 {d17[0]}, [r4]! + VST1.32 {d18[0]}, [r5]! + VST1.32 {d19[0]}, [r3]! + + SUB r1, 4 + VEXT.8 q8, q8, q8, 4 + VEXT.8 q9, q9, q9, 4 + +5: + CMP r1, 2 + BLO 6f + + VST1.16 {d16[0]}, [r2]! + VST1.16 {d17[0]}, [r4]! + VST1.16 {d18[0]}, [r5]! + VST1.16 {d19[0]}, [r3]! + + SUB r1, 2 + VEXT.8 q8, q8, q8, 2 + VEXT.8 q9, q9, q9, 2 + +6: + TEQ r1, 0 + BEQ 7f + + VST1.8 {d16[0]}, [r2] + VST1.8 {d17[0]}, [r4] + VST1.8 {d18[0]}, [r5] + VST1.8 {d19[0]}, [r3] + +7: + VPOP {d8-d15} + POP {r4-r8} + BX lr +END_FUNCTION q8gemm_ukernel_4x8__aarch32_neon_per_channel + +#ifdef __ELF__ +.section ".note.GNU-stack","",%progbits +#endif diff --git a/src/q8gemm/4x8-aarch32-neon.S b/src/q8gemm/4x8-aarch32-neon.S index a8c1021..1cbbc1a 100644 --- a/src/q8gemm/4x8-aarch32-neon.S +++ b/src/q8gemm/4x8-aarch32-neon.S @@ -710,7 +710,7 @@ BEGIN_FUNCTION q8gemm_ukernel_4x8__aarch32_neon VQADD.S16 q9, q9, q7 CMP r0, 2 VQADD.S16 q10, q10, q7 - VQADD.S16 q11, q11, q7 + VQADD.S16 q11, q11, q7 MOVLO r4, r2 VQMOVUN.S16 d16, q8 diff --git a/src/q8gemm/4x8-neon_per_channel.c b/src/q8gemm/4x8-neon_per_channel.c new file mode 100644 index 0000000..c7af2b5 --- /dev/null +++ b/src/q8gemm/4x8-neon_per_channel.c @@ -0,0 +1,368 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + + +void q8gemm_ukernel_4x8__neon_per_channel( + size_t mr, + size_t nr, + size_t k, + const uint8_t* restrict a, + size_t a_stride, + const void* restrict w, + uint8_t* restrict c, + size_t c_stride, + const union qnnp_conv_quantization_params quantization_params[restrict static 1], + size_t kernel_quantization_params_offset) +{ + int32x4_t vacc0x0123 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); + int32x4_t vacc0x4567 = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 16); + int32x4_t vacc1x0123 = vacc0x0123; + int32x4_t vacc1x4567 = vacc0x4567; + int32x4_t vacc2x0123 = vacc0x0123; + int32x4_t vacc2x4567 = vacc0x4567; + int32x4_t vacc3x0123 = vacc0x0123; + int32x4_t vacc3x4567 = vacc0x4567; + + const uint8_t* a0 = a; + const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride); + if (mr < 2) { + a1 = a0; + } + const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride); + if (mr <= 2) { + a2 = a1; + } + const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride); + if (mr != 4) { + a3 = a2; + } + + const uint8x8_t vb_zero_point = vld1_u8((const uint8_t*) &quantization_params->neon.kernel_zero_point_v[kernel_quantization_params_offset]); + for (; k >= 8; k -= 8) { + const uint8x8_t va0 = vld1_u8(a0); a0 += 8; + const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); + const uint8x8_t va1 = vld1_u8(a1); a1 += 8; + const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); + const uint8x8_t va2 = vld1_u8(a2); a2 += 8; + const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); + const uint8x8_t va3 = vld1_u8(a3); a3 += 8; + const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), 
vget_low_s16(vxa3), 1); + + const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 
= vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + + const uint8x8_t vb01234567c7 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c7 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c7, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c7), vget_high_s16(vxa3), 3); + } + if (k != 0) { + const size_t a_predecrement = 8 - k; + const int64x1_t va_shift = vmov_n_s64(-8 * a_predecrement); + const uint8x8_t va0 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a0 - a_predecrement)), va_shift)); + const int16x8_t vxa0 = vreinterpretq_s16_u16(vmovl_u8(va0)); + const uint8x8_t va1 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a1 - a_predecrement)), va_shift)); + const int16x8_t vxa1 = vreinterpretq_s16_u16(vmovl_u8(va1)); + const uint8x8_t va2 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a2 - a_predecrement)), va_shift)); + const int16x8_t vxa2 = vreinterpretq_s16_u16(vmovl_u8(va2)); + const uint8x8_t va3 = vreinterpret_u8_u64(vshl_u64(vreinterpret_u64_u8(vld1_u8(a3 - a_predecrement)), va_shift)); + const int16x8_t vxa3 = vreinterpretq_s16_u16(vmovl_u8(va3)); + + const uint8x8_t vb01234567c0 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c0 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c0, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, 
vget_high_s16(vxb01234567c0), vget_low_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c0), vget_low_s16(vxa3), 0); + + if (k >= 2) { + const uint8x8_t vb01234567c1 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c1 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c1, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c1), vget_low_s16(vxa3), 1); + + if (k >= 3) { + const uint8x8_t vb01234567c2 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c2 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c2, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c2), vget_low_s16(vxa3), 2); + + if (k >= 4) { + const uint8x8_t vb01234567c3 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c3 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c3, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa0), 3); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa1), 3); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa2), 3); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c3), vget_low_s16(vxa3), 3); + + if (k >= 5) { + const uint8x8_t vb01234567c4 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c4 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c4, vb_zero_point)); + + vacc0x0123 = 
vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa0), 0); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa1), 0); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa2), 0); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c4), vget_high_s16(vxa3), 0); + + if (k >= 6) { + const uint8x8_t vb01234567c5 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c5 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c5, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa0), 1); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa1), 1); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa2), 1); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c5), vget_high_s16(vxa3), 1); + + if (k >= 7) { + const uint8x8_t vb01234567c6 = vld1_u8(w); w = (const void*) ((uintptr_t) w + 8); + const int16x8_t vxb01234567c6 = vreinterpretq_s16_u16(vsubl_u8(vb01234567c6, vb_zero_point)); + + vacc0x0123 = vmlal_lane_s16(vacc0x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc0x4567 = vmlal_lane_s16(vacc0x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa0), 2); + vacc1x0123 = vmlal_lane_s16(vacc1x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc1x4567 = vmlal_lane_s16(vacc1x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa1), 2); + vacc2x0123 = vmlal_lane_s16(vacc2x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc2x4567 = vmlal_lane_s16(vacc2x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa2), 2); + vacc3x0123 = vmlal_lane_s16(vacc3x0123, vget_low_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + vacc3x4567 = vmlal_lane_s16(vacc3x4567, vget_high_s16(vxb01234567c6), vget_high_s16(vxa3), 2); + } + } + } + } + } + } + } + + const int32x4_t vmultiplier0x0123 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset]); + const int32x4_t vmultiplier0x4567 = vld1q_s32(&quantization_params->neon.multiplier_v[kernel_quantization_params_offset + 4]); + vacc0x0123 = vqrdmulhq_s32(vacc0x0123, vmultiplier0x0123); + vacc0x4567 = vqrdmulhq_s32(vacc0x4567, vmultiplier0x4567); + vacc1x0123 = vqrdmulhq_s32(vacc1x0123, vmultiplier0x0123); + vacc1x4567 = vqrdmulhq_s32(vacc1x4567, vmultiplier0x4567); + vacc2x0123 = vqrdmulhq_s32(vacc2x0123, vmultiplier0x0123); + vacc2x4567 = vqrdmulhq_s32(vacc2x4567, vmultiplier0x4567); + vacc3x0123 = vqrdmulhq_s32(vacc3x0123, vmultiplier0x0123); + vacc3x4567 = vqrdmulhq_s32(vacc3x4567, vmultiplier0x4567); + + const int32x4_t vright_shift_0x0123 = 
vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset]); + const int32x4_t vright_shift_0x4567 = vld1q_s32(&quantization_params->neon.right_shift_v[kernel_quantization_params_offset + 4]); + const int32x4_t vzero_shift_mask_0x0123 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x0123, vmovq_n_s32(0))); + const int32x4_t vzero_shift_mask_0x4567 = vreinterpretq_s32_u32(vceqq_s32(vright_shift_0x4567, vmovq_n_s32(0))); + vacc0x0123 = vsraq_n_s32(vacc0x0123, vbicq_s32(vacc0x0123, vzero_shift_mask_0x0123), 31); + vacc0x4567 = vsraq_n_s32(vacc0x4567, vbicq_s32(vacc0x4567, vzero_shift_mask_0x4567), 31); + vacc1x0123 = vsraq_n_s32(vacc1x0123, vbicq_s32(vacc1x0123, vzero_shift_mask_0x0123), 31); + vacc1x4567 = vsraq_n_s32(vacc1x4567, vbicq_s32(vacc1x4567, vzero_shift_mask_0x4567), 31); + vacc2x0123 = vsraq_n_s32(vacc2x0123, vbicq_s32(vacc2x0123, vzero_shift_mask_0x0123), 31); + vacc2x4567 = vsraq_n_s32(vacc2x4567, vbicq_s32(vacc2x4567, vzero_shift_mask_0x4567), 31); + vacc3x0123 = vsraq_n_s32(vacc3x0123, vbicq_s32(vacc3x0123, vzero_shift_mask_0x0123), 31); + vacc3x4567 = vsraq_n_s32(vacc3x4567, vbicq_s32(vacc3x4567, vzero_shift_mask_0x4567), 31); + + vacc0x0123 = vrshlq_s32(vacc0x0123, vright_shift_0x0123); + vacc0x4567 = vrshlq_s32(vacc0x4567, vright_shift_0x4567); + vacc1x0123 = vrshlq_s32(vacc1x0123, vright_shift_0x0123); + vacc1x4567 = vrshlq_s32(vacc1x4567, vright_shift_0x4567); + vacc2x0123 = vrshlq_s32(vacc2x0123, vright_shift_0x0123); + vacc2x4567 = vrshlq_s32(vacc2x4567, vright_shift_0x4567); + vacc3x0123 = vrshlq_s32(vacc3x0123, vright_shift_0x0123); + vacc3x4567 = vrshlq_s32(vacc3x4567, vright_shift_0x4567); + + const int16x8_t voutput_zero_point = vld1q_dup_s16(&quantization_params->neon.output_zero_point); +#ifdef __aarch64__ + const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc1x01234567); + uint8x16_t vout2x01234567_3x01234567 = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc3x01234567); +#else + const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point); + const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point); + const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point); + const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point); + + uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc1x01234567)); + uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc3x01234567)); +#endif + const uint8x16_t voutput_min = vld1q_dup_u8(&quantization_params->neon.output_min); + const uint8x16_t voutput_max = vld1q_dup_u8(&quantization_params->neon.output_max); + + vout0x01234567_1x01234567 = vmaxq_u8(vout0x01234567_1x01234567, voutput_min); + vout2x01234567_3x01234567 = 
vmaxq_u8(vout2x01234567_3x01234567, voutput_min); + vout0x01234567_1x01234567 = vminq_u8(vout0x01234567_1x01234567, voutput_max); + vout2x01234567_3x01234567 = vminq_u8(vout2x01234567_3x01234567, voutput_max); + + uint8_t* c0 = c; + uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + c_stride); + if (mr < 2) { + c1 = c0; + } + uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + c_stride); + if (mr <= 2) { + c2 = c1; + } + uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + c_stride); + if (mr != 4) { + c3 = c2; + } + if (nr == 8) { + vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); + vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); + vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); + vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); + } else { + if (nr >= 4) { + vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4; + vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4; + vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4); + vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4); + nr -= 4; + } + if (nr >= 2) { + vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2; + vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2; + vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2); + vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2); + nr -= 2; + } + if (nr != 0) { + vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0); + vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8); + vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0); + vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8); + } + } +} diff --git a/src/qnnpack/pack.h b/src/qnnpack/pack.h index 836b03c..885ed10 100644 --- a/src/qnnpack/pack.h +++ b/src/qnnpack/pack.h @@ -48,6 +48,45 @@ static inline void pack_q8gemm_w( } } +static inline void pack_q8gemm_w_per_channel( + size_t nc, // num output channels + size_t kc, // num input channels + uint32_t nr, // kernel-n-block-size + uint32_t np, // packed-n + uint32_t kr, + uint8_t izp, + uint8_t* kzp, + const uint8_t* k, + const int32_t* b, + void* packed_w) +{ + for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) { + const size_t nr_block_size = min(nc - nr_block_start, nr); + int32_t* packed_b = (int32_t*) packed_w; + for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) { + *((int32_t*) packed_w) = b[nr_block_start + nr_block_offset] + + (int32_t) kc * (int32_t) izp * (int32_t) kzp[nr_block_start + nr_block_offset]; + packed_w = (void*) ((uintptr_t) packed_w + sizeof(int32_t)); + } + packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t)); + for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) { + const size_t kr_block_size = min(kc - kr_block_start, kr); + for (size_t nr_block_offset = 0; 
nr_block_offset < nr_block_size; nr_block_offset++) { + int32_t ksum = 0; + for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) { + const uint8_t kv = k[(nr_block_start + nr_block_offset) * kc + (kr_block_start + kr_block_offset)]; + ksum += (int32_t) kv; + *((uint8_t*) packed_w) = kv; + packed_w = (void*) ((uintptr_t) packed_w + sizeof(uint8_t)); + } + packed_b[nr_block_offset] -= ksum * (int32_t) izp; + packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(uint8_t)); + } + packed_w = (void*) ((uintptr_t) packed_w + ((nr - nr_block_size) & (np - 1)) * kr * sizeof(uint8_t)); + } + } +} + static inline void pack_q8conv_w( size_t n, size_t ks, diff --git a/src/qnnpack/params.h b/src/qnnpack/params.h index e30e237..1d14530 100644 --- a/src/qnnpack/params.h +++ b/src/qnnpack/params.h @@ -145,6 +145,9 @@ union qnnp_conv_quantization_params { int16_t output_zero_point; uint8_t output_max; uint8_t output_min; + uint8_t* kernel_zero_point_v; + int32_t* multiplier_v; + int32_t* right_shift_v; } neon; #endif /* CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 */ #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 @@ -275,6 +278,18 @@ typedef void (*q8gemm_ukernel_function)( size_t c_stride, const union qnnp_conv_quantization_params* quantization_params); +typedef void (*q8gemm_per_channel_ukernel_function)( + size_t mr, + size_t nr, + size_t k, + const uint8_t* a, + size_t a_stride, + const void* w, + uint8_t* c, + size_t c_stride, + const union qnnp_conv_quantization_params* quantization_params, + size_t kernel_quantization_params_offset); + typedef void (*q8conv_ukernel_function)( size_t mr, size_t nr, diff --git a/src/qnnpack/q8gemm.h b/src/qnnpack/q8gemm.h index f5ac117..2bf8753 100644 --- a/src/qnnpack/q8gemm.h +++ b/src/qnnpack/q8gemm.h @@ -43,6 +43,23 @@ DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_8x8__aarch64_neon) DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_2x4c8__sse2) DECLARE_Q8GEMM_UKERNEL_FUNCTION(q8gemm_ukernel_4x4c2__sse2) +#define DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(fn_name) \ + QNNP_INTERNAL void fn_name( \ + size_t mr, \ + size_t nr, \ + size_t k, \ + const uint8_t* a, \ + size_t a_stride, \ + const void* w, \ + uint8_t* c, \ + size_t c_stride, \ + const union qnnp_conv_quantization_params* quantization_params, \ + size_t kernel_quantization_params_offset); + +DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_ukernel_4x8__neon_per_channel) + +DECLARE_Q8GEMM_PER_CHANNEL_UKERNEL_FUNCTION(q8gemm_ukernel_4x8__aarch32_neon_per_channel) + #define DECLARE_Q8GEMM_XZP_UKERNEL_FUNCTION(fn_name) \ QNNP_INTERNAL void fn_name( \ size_t mr, \ diff --git a/src/qnnpack/requantization.h b/src/qnnpack/requantization.h index 2b5fe6f..6c97ac6 100644 --- a/src/qnnpack/requantization.h +++ b/src/qnnpack/requantization.h @@ -197,6 +197,110 @@ static inline union qnnp_conv_quantization_params qnnp_compute_conv_quantization return params; } +static inline union qnnp_conv_quantization_params qnnp_compute_conv_quantization_params_per_channel( + uint8_t input_zero_point, + size_t kernel_params_size, // should be identical to group_output_channels + uint8_t* kernel_zero_point_v, + const float* scale_v, + int32_t* multiplier_v, // pre-allocated in operator-create + int32_t* right_shift_v, // pre-allocated in operator-create + uint8_t output_zero_point, + uint8_t output_min, + uint8_t output_max) +{ + const float scale = *scale_v; + const uint8_t kernel_zero_point = *kernel_zero_point_v; + /* Compute requantization parameters */ + const uint32_t scale_bits = 
fp32_to_bits(scale); + + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t multiplier = (int32_t)(((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(multiplier >= INT32_C(0x40000000)); + assert(multiplier <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t shift = 127 + 31 - 32 - (fp32_to_bits(scale) >> 23); + assert(shift >= 0); + assert(shift < 32); + + union qnnp_conv_quantization_params params; + #if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.input_zero_point[i] = (int16_t) (uint16_t) input_zero_point; + params.sse2.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point; + } + params.sse2.multiplier[0] = multiplier; + params.sse2.multiplier[1] = multiplier; + params.sse2.multiplier[2] = multiplier; + params.sse2.multiplier[3] = multiplier; + params.sse2.rounding[0] = UINT64_C(0x40000000); + params.sse2.rounding[1] = UINT64_C(0x40000000); + params.sse2.remainder_mask[0] = (int32_t) remainder_mask; + params.sse2.remainder_mask[1] = (int32_t) remainder_mask; + params.sse2.remainder_mask[2] = (int32_t) remainder_mask; + params.sse2.remainder_mask[3] = (int32_t) remainder_mask; + params.sse2.remainder_threshold[0] = (int32_t) remainder_threshold; + params.sse2.remainder_threshold[1] = (int32_t) remainder_threshold; + params.sse2.remainder_threshold[2] = (int32_t) remainder_threshold; + params.sse2.remainder_threshold[3] = (int32_t) remainder_threshold; + params.sse2.shift[0] = (uint64_t) (uint32_t) shift; + params.sse2.shift[1] = (uint64_t) (uint32_t) shift; + for (uint32_t i = 0; i < 8; i++) { + params.sse2.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point; + } + for (uint32_t i = 0; i < 16; i++) { + params.sse2.output_max[i] = output_max; + params.sse2.output_min[i] = output_min; + } + #elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + params.neon.input_zero_point = (int16_t) (uint16_t) input_zero_point; + params.neon.kernel_zero_point = (int16_t) (uint16_t) kernel_zero_point; + params.neon.multiplier = multiplier; + params.neon.right_shift = -shift; + params.neon.output_zero_point = (int16_t) (uint16_t) output_zero_point; + params.neon.output_max = output_max; + params.neon.output_min = output_min; + params.neon.kernel_zero_point_v = kernel_zero_point_v; + params.neon.multiplier_v = multiplier_v; + params.neon.right_shift_v = right_shift_v; + for (uint32_t i = 0; i < kernel_params_size; ++i) { + const float s = scale_v[i]; + const uint8_t kzp = kernel_zero_point_v[i]; + /* Compute requantization parameters */ + const uint32_t sbits = fp32_to_bits(s); + /* Multiplier is in [0x40000000, 0x7FFFFF80] range */ + const int32_t m = (int32_t)(((sbits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7); + assert(m >= INT32_C(0x40000000)); + assert(m <= INT32_C(0x7FFFFF80)); + + /* Shift is in [0, 31] range */ + const int32_t rs = 127 + 31 - 32 - (fp32_to_bits(s) >> 23); + assert(rs >= 0); + assert(rs < 32); + params.neon.multiplier_v[i] = m; + params.neon.right_shift_v[i] = -rs; + } + + #else + const uint32_t remainder_mask = (UINT32_C(1) << shift) - UINT32_C(1); + const uint32_t remainder_threshold = remainder_mask >> 1; + params.scalar.input_zero_point = (int32_t) (uint32_t) input_zero_point; + params.scalar.kernel_zero_point = (int32_t) (uint32_t) kernel_zero_point; + params.scalar.multiplier = multiplier; + 
  params.scalar.remainder_mask = (int32_t) remainder_mask;
+  params.scalar.remainder_threshold = (int32_t) remainder_threshold;
+  params.scalar.shift = (uint32_t) shift;
+  params.scalar.output_min_less_zero_point =
+    (int32_t) (uint32_t) output_min - (int32_t) (uint32_t) output_zero_point;
+  params.scalar.output_max_less_zero_point =
+    (int32_t) (uint32_t) output_max - (int32_t) (uint32_t) output_zero_point;
+  params.scalar.output_zero_point = (int32_t) (uint32_t) output_zero_point;
+#endif
+  return params;
+}
+
 static inline union qnnp_avgpool_quantization_params qnnp_compute_avgpool_quantization_params(
   int32_t bias,
   float scale,
diff --git a/test/gemm-microkernel-tester.h b/test/gemm-microkernel-tester.h
index c01de36..9333e03 100644
--- a/test/gemm-microkernel-tester.h
+++ b/test/gemm-microkernel-tester.h
@@ -275,6 +275,128 @@ class GemmMicrokernelTester {
     }
   }
 
+  void test(q8gemm_per_channel_ukernel_function qgemm) const {
+    ASSERT_LE(m(), mr());
+    ASSERT_LE(n(), nr());
+    ASSERT_GE(k(), kr());
+
+    std::random_device randomDevice;
+    auto rng = std::mt19937(randomDevice());
+    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
+    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
+
+    std::vector<uint8_t> a((m() - 1) * aStride() + k() + 8);
+    std::vector<uint8_t> b(n() * k());
+    std::vector<int32_t> bias(n());
+    std::vector<uint8_t, AlignedAllocator<uint8_t, 32>> packedW(packedN() * packedK() + biasN() * sizeof(uint32_t) / sizeof(uint8_t));
+    std::vector<uint8_t> c((m() - 1) * cStride() + n());
+    std::vector<int32_t> acc(m() * n());
+    std::vector<uint8_t> cRef(m() * n());
+
+    // Per-channel quantization parameters
+    std::vector<uint8_t> kernelZeroPointPerChannel(nr());
+    std::vector<float> kernelAndInputScalePerChannel(nr());
+    std::vector<float> requantizationScalePerChannel(nr());
+    std::vector<int32_t> multiplierPerChannel(nr());
+    std::vector<int32_t> rightShiftPerChannel(nr());
+
+    // 1) Fill the per-channel kernel zero points around bZeroPoint() as the center value.
+    // 2) Fill the per-channel kernel-and-input scales by linear interpolation between scale_min and scale_max,
+    //    maintaining requantization_scale < 1 (requantization_scale := input_scale * kernel_scale / output_scale).
+    const float scale_min = 0.5f;
+    const float scale_max = 0.99999f;
+    for (size_t i = 0; i < nr(); ++i) {
+      kernelZeroPointPerChannel[i] =
+        static_cast<uint8_t>(std::min(255, std::max(0, bZeroPoint() + (int) i - (int) (nr() / 2))));
+      kernelAndInputScalePerChannel[i] = scale_min + i * (scale_max - scale_min) / nr();
+    }
+
+    const uint8_t* aPtr = a.data() + 8;
+
+    for (size_t iteration = 0; iteration < iterations(); iteration++) {
+      std::generate(a.begin(), a.end(), std::ref(u8rng));
+      std::generate(b.begin(), b.end(), std::ref(u8rng));
+      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
+      std::fill(c.begin(), c.end(), 0xA5);
+
+      std::fill(packedW.begin(), packedW.end(), bZeroPoint());
+      pack_q8gemm_w_per_channel(n(), k(),
+        nr(), np(), kr(),
+        aZeroPoint(), kernelZeroPointPerChannel.data(),
+        b.data(), bias.data(), packedW.data());
+
+      ASSERT_NE(*std::max_element(a.cbegin(), a.cend()), *std::min_element(a.cbegin(), a.cend()));
+      ASSERT_NE(*std::max_element(b.cbegin(), b.cend()), *std::min_element(b.cbegin(), b.cend()));
+
+      /* Compute 32-bit results and output quantization arguments */
+      std::fill(acc.begin(), acc.end(), 0);
+      for (size_t mIndex = 0; mIndex < m(); mIndex++) {
+        for (size_t nIndex = 0; nIndex < n(); nIndex++) {
+          for (size_t kIndex = 0; kIndex < k(); kIndex++) {
+            ASSERT_LE(n(), packedN());
+            ASSERT_LT(mIndex * n() + nIndex, acc.size());
+            ASSERT_LT(mIndex * k() + kIndex, a.size());
+            acc[mIndex * n() + nIndex] +=
+              (int32_t(aPtr[mIndex * aStride() + kIndex]) - int32_t(aZeroPoint())) *
+              (int32_t(b[nIndex * k() + kIndex]) - int32_t(kernelZeroPointPerChannel[nIndex]));
+          }
+          acc[mIndex * n() + nIndex] += bias[nIndex];
+        }
+      }
+
+      const int32_t accMin = *std::min_element(acc.cbegin(), acc.cend());
+      const int32_t accMax = *std::max_element(acc.cbegin(), acc.cend());
+      if (m() * n() >= 3) {
+        ASSERT_NE(accMax, accMin)
+          << "Mr x Nr x Kr = " << mr() << " x " << nr() << " x " << kr()
+          << ", M x N x K = " << m() << " x " << n() << " x " << k();
+      }
+
+      const double cScale = uint32_t(accMax - accMin) >= 256 ?
double(uint32_t(accMax - accMin)) / 255.0 : 1.00001; + const uint8_t cZeroPoint = uint8_t(std::max(std::min( + lrint(127.5 - 0.5 * double(accMin + accMax) / cScale), + long(std::numeric_limits::max())), long(std::numeric_limits::min()))); + + for (size_t nIndex = 0; nIndex < nr(); nIndex++) { + requantizationScalePerChannel[nIndex] = kernelAndInputScalePerChannel[nIndex] / float(cScale); + } + const union qnnp_conv_quantization_params quantizationParams = + qnnp_compute_conv_quantization_params_per_channel( + aZeroPoint(), nr(), kernelZeroPointPerChannel.data(), + requantizationScalePerChannel.data(), multiplierPerChannel.data(), rightShiftPerChannel.data(), cZeroPoint, qmin(), qmax()); + + qgemm( + m(), n(), k(), + aPtr, aStride() * sizeof(uint8_t), + packedW.data(), + c.data(), cStride() * sizeof(uint8_t), + &quantizationParams, 0); + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + const union qnnp_q31_requantization_params scalarRequantizationParams = + qnnp_compute_scalar_requantization_params( + requantizationScalePerChannel[nIndex], cZeroPoint, qmin(), qmax()); + cRef[mIndex * n() + nIndex] = qnnp_q31_requantize(acc[mIndex * n() + nIndex], scalarRequantizationParams); + } + } + + for (size_t mIndex = 0; mIndex < m(); mIndex++) { + for (size_t nIndex = 0; nIndex < n(); nIndex++) { + ASSERT_LE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmax())); + ASSERT_GE(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(qmin())); + ASSERT_EQ(uint32_t(c[mIndex * cStride() + nIndex]), uint32_t(cRef[mIndex * n() + nIndex])) + << "at " << mIndex << ", " << nIndex << ": reference = " << (uint32_t) cRef[mIndex * n() + nIndex] + << " (accumulator = " << acc[mIndex * n() + nIndex] + << "), optimized = " << (uint32_t) c[mIndex * cStride() + nIndex] << ", Mr x Nr x Kr = " << mr() << " x " + << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k() + << ", requantization scale = " << requantizationScalePerChannel[nIndex] << ", output zero point = " << int32_t(cZeroPoint); + } + } + } + } + void test(q8conv_ukernel_function qconv) const { ASSERT_LE(m(), mr()); ASSERT_LE(n(), nr()); diff --git a/test/q8gemm.cc b/test/q8gemm.cc index 4eb77f0..35254bd 100644 --- a/test/q8gemm.cc +++ b/test/q8gemm.cc @@ -14,7 +14,6 @@ #include "gemm-microkernel-tester.h" - #if CPUINFO_ARCH_ARM TEST(Q8GEMM_4x8__AARCH32_NEON, k_eq_8) { TEST_REQUIRES_ARM_NEON; @@ -605,209 +604,222 @@ } } } -#endif -#if CPUINFO_ARCH_ARM64 - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_strided_a) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .aStride(37) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_strided_c) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .cStride(17) - .test(q8gemm_ukernel_8x8__aarch64_neon); + 
.test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_qmin128) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_qmin128_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .qmin(128) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_qmax128) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_qmax128_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .qmax(128) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_azp0) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_azp0_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .aZeroPoint(0) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_bzp0) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_bzp0_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_nozp) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_eq_8_nozp_per_channel) { + TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_strided_a) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aStride(37) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_strided_c) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_azp0) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_azp0_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aZeroPoint(0) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_bzp0) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_bzp0_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) 
{ GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_nozp) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_nozp_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_subtile) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_gt_8_subtile_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { - for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) @@ -815,62 +827,66 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_strided_a) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_strided_a_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aStride(171) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_strided_c) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_strided_c_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } - TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_subtile) { + TEST(Q8GEMM_4x8__AARCH32_NEON_PER_CHANNEL, k_div_8_subtile_per_channel) { + TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 24) { - for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) @@ -878,229 +894,214 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_8x8__aarch64_neon); + .test(q8gemm_ukernel_4x8__aarch32_neon_per_channel); } } } } #endif -#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 - TEST(Q8GEMM_4x8__NEON, k_eq_8) { - TEST_REQUIRES_ARM_NEON; +#if CPUINFO_ARCH_ARM64 + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_a) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_strided_a) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .aStride(37) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_c) { - 
TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_strided_c) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .cStride(17) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_qmin128) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_qmin128) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .qmin(128) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_qmax128) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_qmax128) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .qmax(128) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_azp0) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_azp0) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .aZeroPoint(0) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_bzp0) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_bzp0) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .bZeroPoint(0) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_eq_8_nozp) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_eq_8_nozp) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(8) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } - TEST(Q8GEMM_4x8__NEON, k_gt_8) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8) { for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_a) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_strided_a) { for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .aStride(37) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_c) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_strided_c) { for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_gt_8_azp0) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_azp0) { for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .aZeroPoint(0) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_gt_8_bzp0) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_bzp0) { for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .bZeroPoint(0) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_gt_8_nozp) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, 
k_gt_8_nozp) { for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_gt_8_subtile) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_gt_8_subtile) { for (size_t k = 9; k < 16; k++) { - for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t m = 1; m <= 8; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) @@ -1108,66 +1109,62 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } } } - TEST(Q8GEMM_4x8__NEON, k_div_8) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8) { for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_div_8_strided_a) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_strided_a) { for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .aStride(171) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_div_8_strided_c) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_strided_c) { for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) - .m(4) + .m(8) .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } - TEST(Q8GEMM_4x8__NEON, k_div_8_subtile) { - TEST_REQUIRES_ARM_NEON; + TEST(Q8GEMM_8x8__AARCH64_NEON, k_div_8_subtile) { for (size_t k = 16; k < 128; k += 24) { - for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t m = 1; m <= 8; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(4) + .mr(8) .nr(8) .np(8) .kr(1) @@ -1175,227 +1172,229 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_4x8__neon); + .test(q8gemm_ukernel_8x8__aarch64_neon); } } } } +#endif - TEST(Q8GEMM_8x8__NEON, k_eq_8) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + TEST(Q8GEMM_4x8__NEON, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_strided_a) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .aStride(37) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_strided_c) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_c) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .cStride(17) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_qmin128) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_qmin128) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .qmin(128) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_qmax128) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_qmax128) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) 
.kr(1) - .m(8) + .m(4) .n(8) .k(8) .qmax(128) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_azp0) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_azp0) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .aZeroPoint(0) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_bzp0) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_bzp0) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_eq_8_nozp) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_nozp) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(8) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } - TEST(Q8GEMM_8x8__NEON, k_gt_8) { + TEST(Q8GEMM_4x8__NEON, k_gt_8) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_gt_8_strided_a) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aStride(37) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_gt_8_strided_c) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_c) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_gt_8_azp0) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_azp0) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aZeroPoint(0) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_gt_8_bzp0) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_bzp0) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_gt_8_nozp) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_nozp) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_gt_8_subtile) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { - for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) @@ -1403,66 +1402,66 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } } } - TEST(Q8GEMM_8x8__NEON, k_div_8) { + TEST(Q8GEMM_4x8__NEON, k_div_8) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) 
.n(8) .k(k) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_div_8_strided_a) { + TEST(Q8GEMM_4x8__NEON, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .aStride(171) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_div_8_strided_c) { + TEST(Q8GEMM_4x8__NEON, k_div_8_strided_c) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) - .m(8) + .m(4) .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } - TEST(Q8GEMM_8x8__NEON, k_div_8_subtile) { + TEST(Q8GEMM_4x8__NEON, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 24) { - for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t m = 1; m <= 4; m++) { for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(8) + .mr(4) .nr(8) .np(8) .kr(1) @@ -1470,96 +1469,308 @@ .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_8x8__neon); + .test(q8gemm_ukernel_4x8__neon); } } } } - TEST(Q8GEMM_6x4__NEON, k_eq_8) { + TEST(Q8GEMM_8x8__NEON, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) + .mr(8) + .nr(8) + .np(8) .kr(1) - .m(6) - .n(4) + .m(8) + .n(8) .k(8) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_ukernel_8x8__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_strided_a) { + TEST(Q8GEMM_8x8__NEON, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) + .mr(8) + .nr(8) + .np(8) .kr(1) - .m(6) - .n(4) + .m(8) + .n(8) .k(8) .aStride(37) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_ukernel_8x8__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_strided_c) { + TEST(Q8GEMM_8x8__NEON, k_eq_8_strided_c) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) + .mr(8) + .nr(8) + .np(8) .kr(1) - .m(6) - .n(4) + .m(8) + .n(8) .k(8) .cStride(17) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_ukernel_8x8__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_qmin128) { + TEST(Q8GEMM_8x8__NEON, k_eq_8_qmin128) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) .k(8) .qmin(128) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_ukernel_8x8__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_qmax128) { + TEST(Q8GEMM_8x8__NEON, k_eq_8_qmax128) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) + .mr(8) + .nr(8) + .np(8) .kr(1) - .m(6) - .n(4) + .m(8) + .n(8) .k(8) .qmax(128) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_ukernel_8x8__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_azp0) { + TEST(Q8GEMM_8x8__NEON, k_eq_8_azp0) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) + .mr(8) + .nr(8) + .np(8) .kr(1) - .m(6) - .n(4) + .m(8) + .n(8) .k(8) .aZeroPoint(0) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_ukernel_8x8__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_bzp0) { + TEST(Q8GEMM_8x8__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .bZeroPoint(0) + .test(q8gemm_ukernel_8x8__neon); + } + + TEST(Q8GEMM_8x8__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + 
.test(q8gemm_ukernel_8x8__neon); + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(37) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aZeroPoint(0) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .bZeroPoint(0) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_ukernel_8x8__neon); + } + } + } + } + + TEST(Q8GEMM_8x8__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .aStride(171) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(8) + .n(8) + .k(k) + .cStride(17) + .test(q8gemm_ukernel_8x8__neon); + } + } + + TEST(Q8GEMM_8x8__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 8; m++) { + for (uint32_t n = 1; n <= 8; n++) { + GemmMicrokernelTester() + .mr(8) + .nr(8) + .np(8) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_ukernel_8x8__neon); + } + } + } + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(6) @@ -1569,11 +1780,10 @@ .m(6) .n(4) .k(8) - .bZeroPoint(0) .test(q8gemm_ukernel_6x4__neon); } - TEST(Q8GEMM_6x4__NEON, k_eq_8_nozp) { + TEST(Q8GEMM_6x4__NEON, k_eq_8_strided_a) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(6) @@ -1583,402 +1793,781 @@ .m(6) .n(4) .k(8) + .aStride(37) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) 
+ .cStride(17) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .qmin(128) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .qmax(128) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .aZeroPoint(0) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .bZeroPoint(0) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(8) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_ukernel_6x4__neon); + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aStride(37) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .cStride(17) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8_azp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aZeroPoint(0) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .bZeroPoint(0) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8_nozp) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aZeroPoint(0) + .bZeroPoint(0) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_gt_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 9; k < 16; k++) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(m) + .n(n) + .k(k) + .iterations(3) + .test(q8gemm_ukernel_6x4__neon); + } + } + } + } + + TEST(Q8GEMM_6x4__NEON, k_div_8) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_div_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .aStride(171) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_div_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 
128; k += 8) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(6) + .n(4) + .k(k) + .cStride(17) + .test(q8gemm_ukernel_6x4__neon); + } + } + + TEST(Q8GEMM_6x4__NEON, k_div_8_subtile) { + TEST_REQUIRES_ARM_NEON; + for (size_t k = 16; k < 128; k += 24) { + for (uint32_t m = 1; m <= 6; m++) { + for (uint32_t n = 1; n <= 4; n++) { + GemmMicrokernelTester() + .mr(6) + .nr(4) + .np(4) + .kr(1) + .m(m) + .n(n) + .k(k) + .test(q8gemm_ukernel_6x4__neon); + } + } + } + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_strided_a) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aStride(37) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_strided_c) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .cStride(17) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_qmin128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .qmin(128) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_qmax128) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .qmax(128) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_azp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .aZeroPoint(0) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_bzp0) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) + .bZeroPoint(0) + .test(q8gemm_xzp_ukernel_4x8c2__neon); + } + + TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_nozp) { + TEST_REQUIRES_ARM_NEON; + GemmMicrokernelTester() + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) + .k(8) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } - TEST(Q8GEMM_6x4__NEON, k_gt_8) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_gt_8_strided_a) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .aStride(37) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_gt_8_strided_c) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_strided_c) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_gt_8_azp0) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_azp0) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { 
GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .aZeroPoint(0) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_gt_8_bzp0) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_bzp0) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .bZeroPoint(0) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_gt_8_nozp) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_nozp) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_gt_8_subtile) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_subtile) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { - for (uint32_t m = 1; m <= 6; m++) { - for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) + .mr(4) + .nr(8) + .np(8) + .kr(2) .m(m) .n(n) .k(k) .iterations(3) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } } } - TEST(Q8GEMM_6x4__NEON, k_div_8) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_div_8_strided_a) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_strided_a) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .aStride(171) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_div_8_strided_c) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_strided_c) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) - .m(6) - .n(4) + .mr(4) + .nr(8) + .np(8) + .kr(2) + .m(4) + .n(8) .k(k) .cStride(17) - .test(q8gemm_ukernel_6x4__neon); + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } - TEST(Q8GEMM_6x4__NEON, k_div_8_subtile) { + TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_subtile) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 24) { - for (uint32_t m = 1; m <= 6; m++) { - for (uint32_t n = 1; n <= 4; n++) { + for (uint32_t m = 1; m <= 4; m++) { + for (uint32_t n = 1; n <= 8; n++) { GemmMicrokernelTester() - .mr(6) - .nr(4) - .np(4) - .kr(1) + .mr(4) + .nr(8) + .np(8) + .kr(2) .m(m) .n(n) .k(k) - .test(q8gemm_ukernel_6x4__neon); + .iterations(3) + .test(q8gemm_xzp_ukernel_4x8c2__neon); } } } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_strided_a) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_a_per_channel) { 
TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .aStride(37) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_strided_c) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_strided_c_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .cStride(17) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_qmin128) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_qmin128_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .qmin(128) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_qmax128) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_qmax128_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .qmax(128) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_azp0) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_azp0_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .aZeroPoint(0) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_bzp0) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_bzp0_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .bZeroPoint(0) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_eq_8_nozp) { + TEST(Q8GEMM_4x8__NEON, k_eq_8_nozp_per_channel) { TEST_REQUIRES_ARM_NEON; GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(8) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_strided_a) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_a_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .aStride(37) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_strided_c) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_strided_c_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .cStride(17) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_azp0) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_azp0_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .aZeroPoint(0) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_bzp0) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_bzp0_per_channel) { 
TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .bZeroPoint(0) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_nozp) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_nozp_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .aZeroPoint(0) .bZeroPoint(0) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_gt_8_subtile) { + TEST(Q8GEMM_4x8__NEON, k_gt_8_subtile_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 9; k < 16; k++) { for (uint32_t m = 1; m <= 4; m++) { @@ -1987,65 +2576,65 @@ .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(m) .n(n) .k(k) .iterations(3) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8) { + TEST(Q8GEMM_4x8__NEON, k_div_8_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_strided_a) { + TEST(Q8GEMM_4x8__NEON, k_div_8_strided_a_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .aStride(171) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_strided_c) { + TEST(Q8GEMM_4x8__NEON, k_div_8_strided_c_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 8) { GemmMicrokernelTester() .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(4) .n(8) .k(k) .cStride(17) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } - TEST(Q8GEMM_4x8c2_XZP__NEON, k_div_8_subtile) { + TEST(Q8GEMM_4x8__NEON, k_div_8_subtile_per_channel) { TEST_REQUIRES_ARM_NEON; for (size_t k = 16; k < 128; k += 24) { for (uint32_t m = 1; m <= 4; m++) { @@ -2054,12 +2643,12 @@ .mr(4) .nr(8) .np(8) - .kr(2) + .kr(1) .m(m) .n(n) .k(k) .iterations(3) - .test(q8gemm_xzp_ukernel_4x8c2__neon); + .test(q8gemm_ukernel_4x8__neon_per_channel); } } }
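Editor's note (not part of the patch): the per-channel parameters consumed by the new kernels are the usual Q31 multiplier/right-shift pair, computed once per output channel exactly as in qnnp_compute_conv_quantization_params_per_channel above. The sketch below restates that derivation as a standalone C helper for reference; fp32_bits is a stand-in for the fp32_to_bits bit-cast used in the real code, and the function and parameter names are illustrative only.

#include <assert.h>
#include <stdint.h>
#include <string.h>

/* Stand-in for fp32_to_bits(): reinterpret a float as its IEEE-754 bit pattern. */
static inline uint32_t fp32_bits(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));
  return bits;
}

/*
 * For each output channel, convert a float requantization scale (0 < scale < 1)
 * into the fixed-point multiplier and right shift used by the per-channel kernels.
 * The mantissa (with its implicit leading 1) becomes a Q31-style multiplier in
 * [0x40000000, 0x7FFFFF80]; the exponent determines the right shift, which is
 * stored negated to match params.neon.right_shift_v.
 */
static void compute_per_channel_requantization(
    size_t channels,
    const float* scale_v,
    int32_t* multiplier_v,
    int32_t* right_shift_v)
{
  for (size_t i = 0; i < channels; i++) {
    const uint32_t scale_bits = fp32_bits(scale_v[i]);

    const int32_t multiplier =
        (int32_t) (((scale_bits & UINT32_C(0x007FFFFF)) | UINT32_C(0x00800000)) << 7);
    assert(multiplier >= INT32_C(0x40000000));
    assert(multiplier <= INT32_C(0x7FFFFF80));

    const int32_t shift = 127 + 31 - 32 - (int32_t) (scale_bits >> 23);
    assert(shift >= 0);
    assert(shift < 32);

    multiplier_v[i] = multiplier;
    right_shift_v[i] = -shift;
  }
}

The benchmark and test fixtures above size these arrays to nr() and fill them through qnnp_compute_conv_quantization_params_per_channel before invoking the per-channel microkernels.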