diff --git a/ggml b/ggml new file mode 160000 index 000000000000..8be60f83ec12 --- /dev/null +++ b/ggml @@ -0,0 +1 @@ +Subproject commit 8be60f83ec124c31f3a427053c29022e3072f8a4 diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index a319913c3f9e..ff26193748e2 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -5,127 +5,6172 @@ #include "layer_type.h" +#include "cpu.h" +#include "x86_usability.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#include +#include +#include +#include + namespace ncnn { -SDPA_x86::SDPA_x86() -{ -#if NCNN_BF16 - support_bf16_storage = true; -#endif +SDPA_x86::SDPA_x86() +{ +#if NCNN_BF16 + support_bf16_storage = true; +#endif + cached_kv_seqlen = -1; + cached_num_group = 0; + cached_embed_dim = 0; + cached_out_embed_dim = 0; +} + +int SDPA_x86::create_pipeline(const Option& /*_opt*/) +{ + if (int8_scale_term) + { + support_bf16_storage = false; + } + + return 0; +} + +int SDPA_x86::destroy_pipeline(const Option& /*_opt*/) +{ + return 0; +} + +#include "sdpa_x86_int8.h" + +#if NCNN_BF16 +#include "sdpa_x86_bf16s.h" +#endif + +static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * d; + for (int j = 0; j < n; j++) + { + const float* kptr = K + j * d; + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + +static inline void vec_scale(float* x, float s, int n) +{ +#if __AVX512F__ + __m512 vscale512 = _mm512_set1_ps(s); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale512)); + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, 
_mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale512)); + } +#else + int i = 0; +#if __SSE2__ +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(s); + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale256)); +#endif // __AVX__ + __m128 vscale128 = _mm_set1_ps(s); + for (; i + 3 < n; i += 4) + _mm_storeu_ps(x + i, _mm_mul_ps(_mm_loadu_ps(x + i), vscale128)); +#endif // __SSE2__ + for (; i < n; i++) + x[i] *= s; +#endif // __AVX512F__ +} + +static inline void vec_zero(float* x, int n) +{ +#if __AVX512F__ + __m512 zero512 = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, zero512); + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, zero512); + } +#else + int i = 0; +#if __SSE2__ +#if __AVX__ + __m256 zero256 = _mm256_setzero_ps(); + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, zero256); +#endif // __AVX__ + __m128 zero128 = _mm_setzero_ps(); + for (; i + 3 < n; i += 4) + _mm_storeu_ps(x + i, zero128); +#endif // __SSE2__ + for (; i < n; i++) + x[i] = 0.f; +#endif // __AVX512F__ +} + +static inline void softmax_tile(float* P, const float* S, + float* m_vec, float* l_vec, float* scale_out, unsigned char* changed_out, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); + if (j < n) + { + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 tail = _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, sptr + j); + vmax = _mm512_max_ps(vmax, tail); + } + float m_new = _mm512_comp_reduce_max_ps(vmax); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; + l_vec[i] *= scale_factor; + + __m512 vm_new = 
_mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < n; j += 16) + { + __m512 svec = _mm512_loadu_ps(sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_storeu_ps(pptr + j, evec); + vsum = _mm512_add_ps(vsum, evec); + } + if (j < n) + { + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 svec = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_mask_storeu_ps(pptr + j, mask, evec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, evec); + } + float l_add = _mm512_comp_reduce_add_ps(vsum); + l_vec[i] += l_add; + m_vec[i] = m_new; +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(m_vec[i]); + int j = 0; + for (; j + 7 < n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(sptr + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; + l_vec[i] *= scale_factor; + + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < n; j += 8) + { + __m256 svec = _mm256_loadu_ps(sptr + j); + __m256 evec = exp256_ps(_mm256_sub_ps(svec, vm_new)); + _mm256_storeu_ps(pptr + j, evec); + vsum = _mm256_add_ps(vsum, evec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(m_vec[i]); + int j = 0; + for (; j + 3 < n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(sptr + j)); + float m_new = _mm_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; + l_vec[i] *= scale_factor; + + __m128 vm_new = _mm_set1_ps(m_new); 
+ __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < n; j += 4) + { + __m128 svec = _mm_loadu_ps(sptr + j); + __m128 evec = exp_ps(_mm_sub_ps(svec, vm_new)); + _mm_storeu_ps(pptr + j, evec); + vsum = _mm_add_ps(vsum, evec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; +#else + float m_new = m_vec[i]; + for (int j = 0; j < n; j++) + m_new = std::max(m_new, sptr[j]); + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; + l_vec[i] *= scale_factor; + float l_add = 0.f; + for (int j = 0; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; +#endif + } +} + +static inline void pv_gemm_scalar(float* O, const float* P, const float* V, + int m, int n, int d) +{ + for (int i = 0; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + { + float p = pptr[j]; + const float* vptr = V + j * d; + for (int k = 0; k < d; k++) + optr[k] += p * vptr[k]; + } + } +} + +static inline void decode_pv_gemv(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) +{ +#if __AVX512F__ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + if (j + 6 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 6) * out_d), _MM_HINT_T1); + + __m512 pvec0 = _mm512_set1_ps(s[j]); + __m512 pvec1 = _mm512_set1_ps(s[j + 1]); + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v00 = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + __m512 v01 = _mm512_loadu_ps(V + (n_start + j) * out_d + k + 16); + __m512 v10 = _mm512_loadu_ps(V + (n_start + j + 1) * out_d + k); + __m512 v11 = _mm512_loadu_ps(V + (n_start + j + 1) * out_d + k + 16); + oval0 = 
_mm512_fmadd_ps(pvec0, v00, oval0); + oval1 = _mm512_fmadd_ps(pvec0, v01, oval1); + oval0 = _mm512_fmadd_ps(pvec1, v10, oval0); + oval1 = _mm512_fmadd_ps(pvec1, v11, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 v0 = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + __m512 v1 = _mm512_loadu_ps(V + (n_start + j + 1) * out_d + k); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_storeu_ps(out + k, oval); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 v0 = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); + __m512 v1 = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j + 1) * out_d + k); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_mask_storeu_ps(out + k, mask_d, oval); + } + } + for (; j < block_n; j++) + { + __m512 pvec512 = _mm512_set1_ps(s[j]); + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v0 = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + __m512 v1 = _mm512_loadu_ps(V + (n_start + j) * out_d + k + 16); + oval0 = _mm512_fmadd_ps(pvec512, v0, oval0); + oval1 = _mm512_fmadd_ps(pvec512, v1, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec512, vval, oval)); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); + _mm512_mask_storeu_ps(out + k, 
mask_d, _mm512_fmadd_ps(pvec512, vval, oval)); + } + } +#else + int j = 0; + for (; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + + int k = 0; +#if __SSE2__ +#if __AVX__ + __m256 pvec256 = _mm256_set1_ps(s[j]); + for (; k + 7 < out_d; k += 8) + { + __m256 oval = _mm256_loadu_ps(out + k); + __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); + _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec256, vval, oval)); + } +#endif // __AVX__ + __m128 pvec128 = _mm_set1_ps(s[j]); + for (; k + 3 < out_d; k += 4) + { + __m128 oval = _mm_loadu_ps(out + k); + __m128 vval = _mm_loadu_ps(V + (n_start + j) * out_d + k); + _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec128, vval, oval)); + } +#endif // __SSE2__ + for (; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; + } +#endif // __AVX512F__ +} + +static inline void decode_qk_dot(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) +{ +#if __AVX512F__ + int j = 0; + if (d >= 256) + { + for (; j + 7 < block_n; j += 8) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + const float* k2 = K + (n_start + j + 2) * d; + const float* k3 = K + (n_start + j + 3) * d; + const float* k4 = K + (n_start + j + 4) * d; + const float* k5 = K + (n_start + j + 5) * d; + const float* k6 = K + (n_start + j + 6) * d; + const float* k7 = K + (n_start + j + 7) * d; + + if (j + 15 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 8) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 9) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 10) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 11) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 12) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 13) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 14) * 
d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 15) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + __m512 acc4 = _mm512_setzero_ps(); + __m512 acc5 = _mm512_setzero_ps(); + __m512 acc6 = _mm512_setzero_ps(); + __m512 acc7 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + acc4 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k4 + k), acc4); + acc5 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k5 + k), acc5); + acc6 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k6 + k), acc6); + acc7 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k7 + k), acc7); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + acc4 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k4 + k), acc4); + acc5 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k5 + k), acc5); + acc6 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k6 + k), acc6); + acc7 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k7 + k), acc7); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + s[j + 4] = _mm512_comp_reduce_add_ps(acc4) * scale; + s[j + 5] = 
_mm512_comp_reduce_add_ps(acc5) * scale; + s[j + 6] = _mm512_comp_reduce_add_ps(acc6) * scale; + s[j + 7] = _mm512_comp_reduce_add_ps(acc7) * scale; + } + } + + for (; j + 3 < block_n; j += 4) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + const float* k2 = K + (n_start + j + 2) * d; + const float* k3 = K + (n_start + j + 3) * d; + + if (j + 7 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 5) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 6) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 7) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + __m512 acc = 
_mm512_setzero_ps(); + int k = 0; + for (; k + 15 < d; k += 16) + acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), _mm512_loadu_ps(K + (n_start + j) * d + k), acc); + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), _mm512_maskz_loadu_ps(mask_d, K + (n_start + j) * d + k), acc); + } + s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + } +#elif __AVX__ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + + if (j + 5 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 3) * d), _MM_HINT_T1); + } + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qv = _mm256_loadu_ps(q + k); + acc0 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + sum0 += q[k] * k0[k]; + sum1 += q[k] * k1[k]; + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + + for (; j < block_n; j++) + { + if (j + 2 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + + const float* kptr = K + (n_start + j) * d; + __m256 acc = _mm256_setzero_ps(); + int k = 0; + for (; k + 7 < d; k += 8) + acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), _mm256_loadu_ps(kptr + k), acc); + float sum = _mm256_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * kptr[k]; + s[j] = sum * scale; + } +#elif __SSE2__ + for (int j = 0; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + __m128 acc = _mm_setzero_ps(); + int k = 0; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(q + k), 
_mm_loadu_ps(K + (n_start + j) * d + k), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * K[(n_start + j) * d + k]; + s[j] = sum * scale; + } +#else + for (int j = 0; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += q[k] * K[(n_start + j) * d + k]; + s[j] = sum * scale; + } +#endif +} + +static inline void decode_mask_vec(float* s, const float* mask, int block_n) +{ +#if __AVX512F__ + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + j))); + } +#elif __AVX__ + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + j))); + for (; j < block_n; j++) + s[j] += mask[j]; +#elif __SSE2__ + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + j))); + for (; j < block_n; j++) + s[j] += mask[j]; +#else + for (int j = 0; j < block_n; j++) + s[j] += mask[j]; +#endif +} + +static inline float decode_max_vec(const float* s, int block_n) +{ +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s + j)); + } + return _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float m = 
_mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + m = std::max(m, s[j]); + return m; +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + m = std::max(m, s[j]); + return m; +#else + float m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + m = std::max(m, s[j]); + return m; +#endif +} + +static inline void decode_exp_sum_vec(float* s, int block_n, float new_m, float* l_add) +{ +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + int j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + *l_add = _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + int j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l += s[j]; + } + *l_add = l; +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + int j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + float l = _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l += s[j]; + } + *l_add = l; +#else + 
float l = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l += s[j]; + } + *l_add = l; +#endif +} + +static void sdpa_decode_chunk( + float* out, float* m_out, float* l_out, + const float* q, const float* K, const float* V, const float* mask, + int n_start, int n_end, int d, int out_d, float scale); + +static void sdpa_decode(float* out, const float* q, + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ + float m, l; + sdpa_decode_chunk(out, &m, &l, q, K, V, mask, 0, n, d, out_d, scale); + float inv_l = 1.f / l; + vec_scale(out, inv_l, out_d); +} + +static void sdpa_decode_chunk( + float* out, float* m_out, float* l_out, + const float* q, const float* K, const float* V, const float* mask, + int n_start, int n_end, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(out, out_d); + + float m = -FLT_MAX; + float l = 0.f; + + for (int n = n_start; n < n_end; n += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n_end - n); + + decode_qk_dot(s, q, K, n, block_n, d, scale); + + if (mask) + decode_mask_vec(s, mask + n, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(out, scale_factor, out_d); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + + decode_pv_gemv(out, s, V, n, block_n, out_d); + + m = new_m; + } + + *m_out = m; + *l_out = l; +} + +static void sdpa_decode_reduce( + float* out, int out_d, + const float* partials, int num_chunks, int partial_stride) +{ + float M_final = -FLT_MAX; + float S_final = 0.f; + vec_zero(out, out_d); + + for (int c = 0; c < num_chunks; c++) + { + 
const float* p = partials + c * partial_stride; + float M_chunk = p[0]; + float S_chunk = p[1]; + if (S_chunk == 0.f) continue; + + const float* VKQ_chunk = p + 2; + + float M_new = std::max(M_final, M_chunk); + float scale_final = expf(M_final - M_new); + float scale_chunk = expf(M_chunk - M_new); + + for (int k = 0; k < out_d; k++) + { + out[k] = out[k] * scale_final + VKQ_chunk[k] * scale_chunk; + } + + S_final = S_final * scale_final + S_chunk * scale_chunk; + M_final = M_new; + } + + if (S_final != 0.f) + { + float inv_s = 1.f / S_final; + vec_scale(out, inv_s, out_d); + } +} + +static inline void sdpa_int8_decode_core( + float* s, float* out, float* m, float* l, + const signed char* q_int8, const float* q_scale, + const signed char* key_int8, const float* key_scales, + const signed char* value_int8, const float* value_scales, + const float* mask_ptr, + int n_start, int block_n, int embed_dim, int out_embed_dim, float scale) +{ + decode_qk_dot_int8(s, q_int8, key_int8, q_scale, key_scales, n_start, block_n, embed_dim, scale); + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n_start, block_n); + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(*m, tile_m); + if (*m != new_m) + { + float scale_factor = expf(*m - new_m); + *l *= scale_factor; + vec_scale(out, scale_factor, out_embed_dim); + } + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + *l += l_add; + decode_pv_gemv_int8(out, s, value_int8, value_scales, n_start, block_n, out_embed_dim); + *m = new_m; +} + +#if __AVX512F__ + +static void qk_gemm_avx512(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + const int tail = d & 15; + const __mmask16 tail_mask = (__mmask16)((1u << tail) - 1); + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); 
+ acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (tail) + { + __m512 kv0 = _mm512_maskz_loadu_ps(tail_mask, k0 + k); + __m512 kv1 = _mm512_maskz_loadu_ps(tail_mask, k1 + k); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * d; + + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (tail) + { + __m512 kvec = _mm512_maskz_loadu_ps(tail_mask, kptr + k); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 
15 < d; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (tail) + { + __m512 kv0 = _mm512_maskz_loadu_ps(tail_mask, k0 + k); + __m512 kv1 = _mm512_maskz_loadu_ps(tail_mask, k1 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * d; + + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (tail) + { + __m512 kvec = _mm512_maskz_loadu_ps(tail_mask, kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 3 < n; j += 4) + { + const float* qptr = Q + i * d; + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + const float* k2 = K + (j + 2) * d; + const float* k3 = K + (j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = 
_mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + if (tail) + { + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + int k = 0; + __m512 vacc = _mm512_setzero_ps(); + for (; k + 15 < d; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + if (tail) + { + vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(tail_mask, qptr + k), _mm512_maskz_loadu_ps(tail_mask, kptr + k), vacc); + } + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + +// S[i * n + j] = scale * dot(Q row i, K row j) with the row length D fixed at compile time. +// Every k-loop steps by 16 floats with full-width loads and no tail, so D is expected to be a multiple of 16. +template<int D> +static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = 
_mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + const float* kptr = K + j * D; + for (int k = 0; k < D; k += 16) + { + __m512 kvec = 
_mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 3 < n; j += 4) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < D; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + +template +static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + if (M_BLOCK2 > 0) + { + for (; i + M_BLOCK2 <= m; i += M_BLOCK2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + if (D >= 512 && n >= 128 && j + 8 <= n) + { + _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); + 
_mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 6) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 7) * D), _MM_HINT_T1); + } + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m512 acc[M_BLOCK2][4]; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m512 acc[M_BLOCK2][2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = 
_mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m512 acc[M_BLOCK2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + } + + for (; i + M_BLOCK1 <= m; i += M_BLOCK1) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + if (D >= 512 && n >= 128 && j + 8 <= n) + { + _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 6) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 7) * D), _MM_HINT_T1); + } + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m512 acc[M_BLOCK1][4]; + for (int mi = 0; mi < M_BLOCK1; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + 
acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m512 acc[M_BLOCK1][2]; + for (int mi = 0; mi < M_BLOCK1; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m512 acc[M_BLOCK1]; + for (int mi = 0; mi < M_BLOCK1; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m512 
acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < D; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<128, 4, 6>(S, Q, K, m, n, scale); +} +template<> +void qk_gemm_specialized_avx512<512>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<512, 4, 8>(S, Q, K, m, n, 
scale); +} +template<> +void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<2048, 2, 0>(S, Q, K, m, n, scale); +} +template<> +void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<1024, 4, 4>(S, Q, K, m, n, scale); +} + +template<> +void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); +} + +template +static inline void qk_gemm_specialized_avx512_large_m(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ +} + +template<> +void qk_gemm_specialized_avx512_large_m<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<4096, 4, 6>(S, Q, K, m, n, scale); +} + +template +static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 16; + const bool prefetch_v = d >= 512 && n >= 256; + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m512 acc[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op[mi] + vi * 16); + + for (int j = 0; j < n; j++) + { + if (prefetch_v && j + 4 < n) + _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); + __m512 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm512_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(op[mi] + vi * 16, acc[mi][vi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + + __m512 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr + vi * 16); + + for (int j = 0; j < n; j++) + { + if (prefetch_v && j + 4 < n) + _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); + __m512 pvec = _mm512_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(optr + vi * 16, acc[vi]); + } + } + + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * d + dd), acc); + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + if (INIT_ZERO) + { + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] = 0.f; + } + for (int j = 0; j < n; j++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] += pptr[mi][j] * V[j * d + dd]; + } + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + if (INIT_ZERO) + optr[0] = 0.f; + for (int j = 0; j < n; j++) + optr[0] += pptr[j] * V[j * d + dd]; + } + } +} + +template +static inline void pv_gemm_2heads_4096_avx512(float* O0, float* O1, const float* P0, const float* P1, const float* V, int m, int n) +{ + const int d = 4096; + const int VEC_PER_UNROLL = D_UNROLL / 16; + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op0[M_BLOCK]; + float* op1[M_BLOCK]; + const float* pptr0[M_BLOCK]; + const float* pptr1[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op0[mi] = O0 + (i + mi) * d + dd; + op1[mi] = O1 + (i + mi) * d + dd; + pptr0[mi] = P0 + (i + mi) * n; + pptr1[mi] = P1 + (i + mi) * n; + } + + __m512 acc0[M_BLOCK][VEC_PER_UNROLL]; + __m512 
acc1[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op0[mi] + vi * 16); + acc1[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op1[mi] + vi * 16); + } + } + + for (int j = 0; j < n; j++) + { + __m512 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[mi][j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = _mm512_fmadd_ps(pvec0, vvec[vi], acc0[mi][vi]); + acc1[mi][vi] = _mm512_fmadd_ps(pvec1, vvec[vi], acc1[mi][vi]); + } + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(op0[mi] + vi * 16, acc0[mi][vi]); + _mm512_storeu_ps(op1[mi] + vi * 16, acc1[mi][vi]); + } + } + } + + for (; i < m; i++) + { + float* optr0 = O0 + i * d + dd; + float* optr1 = O1 + i * d + dd; + const float* pptr0 = P0 + i * n; + const float* pptr1 = P1 + i * n; + + __m512 acc0[VEC_PER_UNROLL]; + __m512 acc1[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr0 + vi * 16); + acc1[vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(optr1 + vi * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd + vi * 16); + acc0[vi] = _mm512_fmadd_ps(pvec0, vvec, acc0[vi]); + acc1[vi] = _mm512_fmadd_ps(pvec1, vvec, acc1[vi]); + } + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(optr0 + vi * 16, acc0[vi]); + _mm512_storeu_ps(optr1 + vi * 16, acc1[vi]); + } + } + } +} + +template +static inline void pv_gemm_4heads_4096_avx512(float* O0, float* O1, float* O2, float* O3, const float* P0, const float* P1, const float* P2, const float* P3, const float* V, int m, int n) +{ + const int d = 4096; + const int VEC_PER_UNROLL = D_UNROLL / 16; + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op0[M_BLOCK]; + float* op1[M_BLOCK]; + float* op2[M_BLOCK]; + float* op3[M_BLOCK]; + const float* pptr0[M_BLOCK]; + const float* pptr1[M_BLOCK]; + const float* pptr2[M_BLOCK]; + const float* pptr3[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op0[mi] = O0 + (i + mi) * d + dd; + op1[mi] = O1 + (i + mi) * d + dd; + op2[mi] = O2 + (i + mi) * d + dd; + op3[mi] = O3 + (i + mi) * d + dd; + pptr0[mi] = P0 + (i + mi) * n; + pptr1[mi] = P1 + (i + mi) * n; + pptr2[mi] = P2 + (i + mi) * n; + pptr3[mi] = P3 + (i + mi) * n; + } + + __m512 acc0[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc1[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc2[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc3[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op0[mi] + vi * 16); + acc1[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op1[mi] + vi * 16); + acc2[mi][vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op2[mi] + vi * 16); + acc3[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op3[mi] + vi * 16); + } + } + + for (int j = 0; j < n; j++) + { + __m512 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[mi][j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[mi][j]); + __m512 pvec2 = _mm512_set1_ps(pptr2[mi][j]); + __m512 pvec3 = _mm512_set1_ps(pptr3[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = _mm512_fmadd_ps(pvec0, vvec[vi], acc0[mi][vi]); + acc1[mi][vi] = _mm512_fmadd_ps(pvec1, vvec[vi], acc1[mi][vi]); + acc2[mi][vi] = _mm512_fmadd_ps(pvec2, vvec[vi], acc2[mi][vi]); + acc3[mi][vi] = _mm512_fmadd_ps(pvec3, vvec[vi], acc3[mi][vi]); + } + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(op0[mi] + vi * 16, acc0[mi][vi]); + _mm512_storeu_ps(op1[mi] + vi * 16, acc1[mi][vi]); + _mm512_storeu_ps(op2[mi] + vi * 16, acc2[mi][vi]); + _mm512_storeu_ps(op3[mi] + vi * 16, acc3[mi][vi]); + } + } + } + + for (; i < m; i++) + { + float* optr0 = O0 + i * d + dd; + float* optr1 = O1 + i * d + dd; + float* optr2 = O2 + i * d + dd; + float* optr3 = O3 + i * d + dd; + const float* pptr0 = P0 + i * n; + const float* pptr1 = P1 + i * n; + const float* pptr2 = P2 + i * n; + const float* pptr3 = P3 + i * n; + + __m512 acc0[VEC_PER_UNROLL]; + __m512 acc1[VEC_PER_UNROLL]; + __m512 acc2[VEC_PER_UNROLL]; + __m512 acc3[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr0 + vi * 16); + acc1[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr1 + vi * 16); + acc2[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr2 + vi * 16); + acc3[vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(optr3 + vi * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[j]); + __m512 pvec2 = _mm512_set1_ps(pptr2[j]); + __m512 pvec3 = _mm512_set1_ps(pptr3[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd + vi * 16); + acc0[vi] = _mm512_fmadd_ps(pvec0, vvec, acc0[vi]); + acc1[vi] = _mm512_fmadd_ps(pvec1, vvec, acc1[vi]); + acc2[vi] = _mm512_fmadd_ps(pvec2, vvec, acc2[vi]); + acc3[vi] = _mm512_fmadd_ps(pvec3, vvec, acc3[vi]); + } + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(optr0 + vi * 16, acc0[vi]); + _mm512_storeu_ps(optr1 + vi * 16, acc1[vi]); + _mm512_storeu_ps(optr2 + vi * 16, acc2[vi]); + _mm512_storeu_ps(optr3 + vi * 16, acc3[vi]); + } + } + } +} + +template +static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) +{ + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + 127 < D; dd += 128) + { + __m512 acc[M_BLOCK][8]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + dd + 0 * 16); + acc[mi][1] = _mm512_loadu_ps(op[mi] + dd + 1 * 16); + acc[mi][2] = _mm512_loadu_ps(op[mi] + dd + 2 * 16); + acc[mi][3] = _mm512_loadu_ps(op[mi] + dd + 3 * 16); + acc[mi][4] = _mm512_loadu_ps(op[mi] + dd + 4 * 16); + acc[mi][5] = _mm512_loadu_ps(op[mi] + dd + 5 * 16); + acc[mi][6] = _mm512_loadu_ps(op[mi] + dd + 6 * 16); + acc[mi][7] = _mm512_loadu_ps(op[mi] + dd + 7 * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 v0 = _mm512_loadu_ps(V + j * D + dd + 0 * 16); + __m512 v1 = _mm512_loadu_ps(V + j * D + dd + 1 * 16); + __m512 v2 = _mm512_loadu_ps(V + j * D + dd + 2 * 16); + __m512 v3 = _mm512_loadu_ps(V + j * D + dd + 3 * 16); + __m512 v4 
= _mm512_loadu_ps(V + j * D + dd + 4 * 16); + __m512 v5 = _mm512_loadu_ps(V + j * D + dd + 5 * 16); + __m512 v6 = _mm512_loadu_ps(V + j * D + dd + 6 * 16); + __m512 v7 = _mm512_loadu_ps(V + j * D + dd + 7 * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + acc[mi][0] = _mm512_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(pvec, v3, acc[mi][3]); + acc[mi][4] = _mm512_fmadd_ps(pvec, v4, acc[mi][4]); + acc[mi][5] = _mm512_fmadd_ps(pvec, v5, acc[mi][5]); + acc[mi][6] = _mm512_fmadd_ps(pvec, v6, acc[mi][6]); + acc[mi][7] = _mm512_fmadd_ps(pvec, v7, acc[mi][7]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + _mm512_storeu_ps(op[mi] + dd + 0 * 16, acc[mi][0]); + _mm512_storeu_ps(op[mi] + dd + 1 * 16, acc[mi][1]); + _mm512_storeu_ps(op[mi] + dd + 2 * 16, acc[mi][2]); + _mm512_storeu_ps(op[mi] + dd + 3 * 16, acc[mi][3]); + _mm512_storeu_ps(op[mi] + dd + 4 * 16, acc[mi][4]); + _mm512_storeu_ps(op[mi] + dd + 5 * 16, acc[mi][5]); + _mm512_storeu_ps(op[mi] + dd + 6 * 16, acc[mi][6]); + _mm512_storeu_ps(op[mi] + dd + 7 * 16, acc[mi][7]); + } + } + + for (; dd + 15 < D; dd += 16) + { + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m512 vvec = _mm512_loadu_ps(V + j * D + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < D; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * D + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + 127 < D; dd += 128) + { + 
__m512 acc0 = _mm512_loadu_ps(optr + dd + 0 * 16); + __m512 acc1 = _mm512_loadu_ps(optr + dd + 1 * 16); + __m512 acc2 = _mm512_loadu_ps(optr + dd + 2 * 16); + __m512 acc3 = _mm512_loadu_ps(optr + dd + 3 * 16); + __m512 acc4 = _mm512_loadu_ps(optr + dd + 4 * 16); + __m512 acc5 = _mm512_loadu_ps(optr + dd + 5 * 16); + __m512 acc6 = _mm512_loadu_ps(optr + dd + 6 * 16); + __m512 acc7 = _mm512_loadu_ps(optr + dd + 7 * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + acc0 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 0 * 16), acc0); + acc1 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 1 * 16), acc1); + acc2 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 2 * 16), acc2); + acc3 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 3 * 16), acc3); + acc4 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 4 * 16), acc4); + acc5 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 5 * 16), acc5); + acc6 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 6 * 16), acc6); + acc7 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 7 * 16), acc7); + } + + _mm512_storeu_ps(optr + dd + 0 * 16, acc0); + _mm512_storeu_ps(optr + dd + 1 * 16, acc1); + _mm512_storeu_ps(optr + dd + 2 * 16, acc2); + _mm512_storeu_ps(optr + dd + 3 * 16, acc3); + _mm512_storeu_ps(optr + dd + 4 * 16, acc4); + _mm512_storeu_ps(optr + dd + 5 * 16, acc5); + _mm512_storeu_ps(optr + dd + 6 * 16, acc6); + _mm512_storeu_ps(optr + dd + 7 * 16, acc7); + } + + for (; dd + 15 < D; dd += 16) + { + __m512 acc = _mm512_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * D + dd), acc); + _mm512_storeu_ps(optr + dd, acc); + } + + for (; dd < D; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * D + dd]; + optr[dd] = acc; + } + } +} + +#endif // __AVX512F__ + +#if __AVX__ + +static void qk_gemm_avx(float* S, const float* Q, 
const float* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m256 acc[6][2]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum0 = _mm256_reduce_add_ps(acc[mi][0]); + float sum1 = _mm256_reduce_add_ps(acc[mi][1]); + + for (; k < d; k++) + { + float qv = Q[(i + mi) * d + k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + } + + S[(i + mi) * n + j + 0] = sum0 * scale; + S[(i + mi) * n + j + 1] = sum1 * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * d; + + __m256 acc[6]; + for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum = _mm256_reduce_add_ps(acc[mi]); + for (; k < d; k++) + sum += Q[(i + mi) * d + k] * kptr[k]; + S[(i + mi) * n + j] = sum * scale; + } + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * d; + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = 
_mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + } + + S[i * n + j + 0] = sum0 * scale; + S[i * n + j + 1] = sum1 * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m256 vacc = _mm256_setzero_ps(); + for (; k + 7 < d; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + sum = _mm256_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + +template +static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc[6][2]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m256 acc[6]; + for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi 
= 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < D; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + +template +static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + if (M_BLOCK2 > 0) + { + for (; i + M_BLOCK2 <= m; i += M_BLOCK2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m256 acc[M_BLOCK2][4]; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); + + 
for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc[M_BLOCK2][2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m256 acc[M_BLOCK2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + } 
+ + for (; i + M_BLOCK1 <= m; i += M_BLOCK1) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m256 acc[M_BLOCK1][4]; + for (int mi = 0; mi < M_BLOCK1; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc[M_BLOCK1][2]; + for (int mi = 0; mi < M_BLOCK1; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + 
S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m256 acc[M_BLOCK1]; + for (int mi = 0; mi < M_BLOCK1; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + acc2 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k2 + k), acc2); + acc3 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm256_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm256_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, 
_mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < D; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_avx<512>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<512, 2, 4>(S, Q, K, m, n, scale); +} +template<> +void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<1024, 4, 2>(S, Q, K, m, n, scale); +} + +template<> +void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<2048, 2, 0>(S, Q, K, m, n, scale); +} + +template<> +void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<4096, 2, 2>(S, Q, K, m, n, scale); +} + +template +static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m256 acc[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm256_loadu_ps(op[mi] + dd + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 vvec[VEC_PER_UNROLL]; + 
for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm256_loadu_ps(V + j * d + dd + vi * 8); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm256_comp_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm256_storeu_ps(op[mi] + dd + vi * 8, acc[mi][vi]); + } + + for (; dd + 7 < d; dd += 8) + { + __m256 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm256_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m256 vvec = _mm256_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm256_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < d; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m256 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_loadu_ps(optr + dd + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm256_storeu_ps(optr + dd + vi * 8, acc[vi]); + } + + for (; dd + 7 < d; dd += 8) + { + __m256 acc = _mm256_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); + _mm256_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; 
j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; + } + } +} + +template +static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(optr + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(optr + vi * 8, acc[vi]); + } +} + +#endif // __AVX__ + +#if __SSE2__ + +static void qk_gemm_sse2(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j < n; j++) + { + const float* kptr = K + j * d; + + __m128 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm_setzero_ps(); + + int k = 0; + for (; k + 3 < d; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 
4; mi++) + { + float sum = _mm_reduce_add_ps(acc[mi]); + for (; k < d; k++) + sum += Q[(i + mi) * d + k] * kptr[k]; + S[(i + mi) * n + j] = sum * scale; + } + } + } + + for (; i < m; i++) + { + for (int j = 0; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m128 vacc = _mm_setzero_ps(); + for (; k + 3 < d; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + sum = _mm_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + +template +static inline void qk_gemm_specialized_sse2(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + for (int j = 0; j < n; j++) + { + const float* kptr = K + j * D; + + const float* q0 = Q + (i + 0) * D; + const float* q1 = Q + (i + 1) * D; + const float* q2 = Q + (i + 2) * D; + const float* q3 = Q + (i + 3) * D; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + __m128 acc2 = _mm_setzero_ps(); + __m128 acc3 = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + + __m128 qvec = _mm_loadu_ps(q0 + k); + acc0 = _mm_comp_fmadd_ps(qvec, kvec, acc0); + + qvec = _mm_loadu_ps(q1 + k); + acc1 = _mm_comp_fmadd_ps(qvec, kvec, acc1); + + qvec = _mm_loadu_ps(q2 + k); + acc2 = _mm_comp_fmadd_ps(qvec, kvec, acc2); + + qvec = _mm_loadu_ps(q3 + k); + acc3 = _mm_comp_fmadd_ps(qvec, kvec, acc3); + } + + S[(i + 0) * n + j] = _mm_reduce_add_ps(acc0) * scale; + S[(i + 1) * n + j] = _mm_reduce_add_ps(acc1) * scale; + S[(i + 2) * n + j] = _mm_reduce_add_ps(acc2) * scale; + S[(i + 3) * n + j] = _mm_reduce_add_ps(acc3) * scale; + } + } + + for (; i < m; i++) + { + for (int j = 0; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m128 vacc = _mm_setzero_ps(); + for (int k = 0; k < D; k += 4) + vacc = 
_mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; + } + } +} + +template +static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + if (M_BLOCK2 > 0) + { + for (; i + M_BLOCK2 <= m; i += M_BLOCK2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m128 acc[M_BLOCK2][4]; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); + } + + for (int k = 0; k < D; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m128 acc[M_BLOCK2][2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + } + + for (int k = 0; k < D; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + + for (int mi = 0; mi < 
M_BLOCK2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m128 acc[M_BLOCK2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + acc[mi] = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < M_BLOCK2; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; + } + } + } + + for (; i + M_BLOCK1 <= m; i += M_BLOCK1) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m128 acc[M_BLOCK1][4]; + for (int mi = 0; mi < M_BLOCK1; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); + } + + for (int k = 0; k < D; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] 
= _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m128 acc[M_BLOCK1][2]; + for (int mi = 0; mi < M_BLOCK1; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + } + + for (int k = 0; k < D; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m128 acc[M_BLOCK1]; + for (int mi = 0; mi < M_BLOCK1; mi++) + acc[mi] = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK1; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < M_BLOCK1; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + __m128 acc2 = _mm_setzero_ps(); + __m128 acc3 = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = 
_mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + acc2 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k2 + k), acc2); + acc3 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m128 vacc = _mm_setzero_ps(); + for (int k = 0; k < D; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_sse2<512>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<512, 2, 4>(S, Q, K, m, n, scale); +} +template<> +void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<1024, 2, 4>(S, Q, K, m, n, scale); +} + +template<> +void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<2048, 2, 0>(S, Q, K, m, n, scale); +} + +template<> +void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<4096, 2, 2>(S, Q, 
K, m, n, scale); +} + +template +static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 4; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m128 acc[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm_loadu_ps(op[mi] + dd + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm_loadu_ps(V + j * d + dd + vi * 4); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m128 pvec = _mm_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm_comp_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm_storeu_ps(op[mi] + dd + vi * 4, acc[mi][vi]); + } + + for (; dd + 3 < d; dd += 4) + { + __m128 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m128 vvec = _mm_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < d; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m128 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + 
acc[vi] = _mm_loadu_ps(optr + dd + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * d + dd + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm_storeu_ps(optr + dd + vi * 4, acc[vi]); + } + + for (; dd + 3 < d; dd += 4) + { + __m128 acc = _mm_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), _mm_loadu_ps(V + j * d + dd), acc); + _mm_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; + } + } +} + +template +static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 4; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m128 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_loadu_ps(op[mi] + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * D + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm_storeu_ps(op[mi] + vi * 4, acc[vi]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + __m128 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_loadu_ps(optr + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * D + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm_storeu_ps(optr + vi * 4, 
acc[vi]); + } +} + +#endif // __SSE2__ + +// Timing instrumentation removed + +static int sdpa_forward_prefill( + const Mat& query_ref, + const Mat& query_bf16_ref, + const Mat& attn_mask_ref, + Mat& key, + Mat& value, + Mat& top_blob, + std::vector& top_blobs, + const Option& opt, + int embed_dim, + int src_seqlen, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int kv_cache, + int attn_mask, + bool use_bf16_path) +{ + (void)num_heads; + (void)query_bf16_ref; + const int BLOCK_M = 64; + const int BLOCK_N = 128; + Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + const bool large_dim = embed_dim > 512 && (src_seqlen > 16 || num_heads_per_group > 1); + const bool pack_q = !large_dim && num_heads_per_group > 1; + Mat q_batch; + if (pack_q) + q_batch.create(embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + + if (s_vec.empty() || o_accum.empty() || (pack_q && q_batch.empty())) + return -100; + +#if __AVX512F__ +#if NCNN_BF16 && NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + const bool bf16_init_zero_path = use_bf16_path && out_embed_dim == 4096 && ncnn::cpu_support_x86_avx512_bf16(); +#else + const bool bf16_init_zero_path = false; +#endif + const bool skip_o_accum_zero = out_embed_dim == 4096 && (!use_bf16_path || bf16_init_zero_path); +#else + const bool skip_o_accum_zero = false; +#endif + + int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; + + // Per-head per-M-tile softmax state for cross-N-tile accumulation + Mat m_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); + Mat l_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); + + if (m_state.empty() || l_state.empty()) + return -100; + + #pragma omp parallel 
for num_threads(opt.num_threads) if (num_group * num_m_tiles > 1) + for (int idx = 0; idx < num_group * num_m_tiles; idx++) + { + int g = idx / num_m_tiles; + int m_tile = idx % num_m_tiles; + int m_start = m_tile * BLOCK_M; + int block_m = m_start + BLOCK_M < src_seqlen ? BLOCK_M : src_seqlen - m_start; + + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); + Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); + Mat q_batch_thread; + if (pack_q) + q_batch_thread = q_batch.channel(get_omp_thread_num()); + + // Pre-resolve mask pointers for all heads in this group + const float* mask_data[num_heads_per_group]; + int mask_stride[num_heads_per_group]; + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + mask_data[hq] = nullptr; + mask_stride[hq] = 0; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mh = (maskm.dims == 3 && maskm.c > 1) ? maskm.channel(q) + : (maskm.dims == 3 ? 
maskm.channel(0) : maskm); + mask_data[hq] = mh; + mask_stride[hq] = mh.w; + } + } + + Mat m_state_tile = m_state.channel(idx); + Mat l_state_tile = l_state.channel(idx); + + Mat query_head_unpacked; + const float* q_data = 0; + if (!large_dim && !pack_q) + { + query_head_unpacked = query_ref.channel(g * num_heads_per_group); + q_data = query_head_unpacked.row(m_start); + } + + if (pack_q) + { + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + float* q_dst = q_batch_thread.row(hq * block_m); + for (int i = 0; i < block_m; i++) + { + memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); + } + } + } + + // Initialize softmax state and zero output accumulator for all Q heads in this group + for (int hq = 0; hq < num_heads_per_group; hq++) + { + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } + + float* o_ptr = o_accum_thread.row(hq * block_m); + if (!skip_o_accum_zero) + vec_zero(o_ptr, out_embed_dim * block_m); + } + + // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group. + // The large-dim path below has its own N loop over all tiles. + const int n_loop_end = large_dim ? 1 : dst_seqlen; + for (int n_start = 0; n_start < n_loop_end; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + float* s_ptr = s_vec_thread.row(0); + + if (!large_dim) + { + if (pack_q) + q_data = q_batch_thread.row(0); +#if NCNN_BF16 + if (use_bf16_path) + { +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + qk_gemm_bf16s_avx512bf16(s_ptr, + q_data, + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } + else +#endif + { + qk_gemm_bf16s_avx512_kernel(s_ptr, + q_data, + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } +#elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_ptr, + q_data, + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_ptr, + q_data, + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); +#else + qk_gemm_bf16s_scalar_kernel(s_ptr, + q_data, + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); +#endif + } + else +#endif + { + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, 
_scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 32: +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<80>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif 
__AVX__ + qk_gemm_specialized_avx<96>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_specialized_avx512<4096>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<4096>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + 
qk_gemm_specialized_sse2<4096>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#else + qk_gemm_scalar(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, 
block_n, embed_dim, _scale); +#endif + break; + } + } + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + float* s_head = s_ptr + hq * block_m * block_n; + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; + float* sptr = s_head + i * block_n; + decode_mask_vec(sptr, mptr, block_n); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + } + +#if NCNN_BF16 + if (use_bf16_path) + { +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim, n_start == 0 && out_embed_dim == 4096); + } + else +#endif + { + pv_gemm_bf16s_avx512_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } +#elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); +#elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); +#else + pv_gemm_bf16s_scalar_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); +#endif + } + else +#endif + { + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * 
num_heads_per_group, block_n); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * 
num_heads_per_group, block_n, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + if (n_start == 0) + pv_gemm_avx512<2, 128, true>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + else + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), 
s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#else + pv_gemm_scalar(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#endif + break; + } + } + } + else + { + // large_dim: N-outer loop, head-inner loop to reuse K/V tiles across heads + for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) + { + int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? 
n_start2 + BLOCK_N : dst_seqlen; + int block_n2 = n_end2 - n_start2; + +#if __AVX512F__ + if (!use_bf16_path && !attn_mask && num_heads_per_group == 1 && embed_dim == 4096 && out_embed_dim == 4096) + { + const Mat query_head = query_ref.channel(g); + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr; + + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec = m_state_tile.row(0); + float* l_vec = l_state_tile.row(0); + float* o_ptr = o_accum_thread.row(0); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * 4096, scale_factors[i], 4096); + } + } + + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + + continue; + } +#endif + + if (num_group * num_m_tiles == 1) + { + #pragma omp parallel for num_threads(opt.num_threads) if (opt.num_threads > 1) + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr + hq * block_m * block_n2; + +#if NCNN_BF16 + if (use_bf16_path) + { +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + const Mat query_bf16_head = query_bf16_ref.channel(q); + qk_gemm_bf16s_avx512bf16_qbf16(s_head, + query_bf16_head.row(m_start), + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else +#endif + { + qk_gemm_bf16s_avx512_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } +#elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_head, 
+ q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_bf16s_scalar_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#endif + } + else +#endif + { + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 32: +#if __AVX512F__ + 
qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, 
q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 
embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#endif + break; + } + } + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; + float* sptr = s_head + i * block_n2; + decode_mask_vec(sptr, mptr, block_n2); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + +#if NCNN_BF16 + if (use_bf16_path) + { +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim, n_start2 == 0 && out_embed_dim == 4096); + } + else +#endif + { + pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } +#elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#else + pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#endif + } + else +#endif + { + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, 
s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 
out_embed_dim); +#else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#endif + break; + } + } + } + } + else + { +#if __AVX512F__ + if (!use_bf16_path && !attn_mask && embed_dim == 4096 && out_embed_dim == 4096 && num_heads_per_group >= 4) + { + int hq = 0; + for (; hq + 3 < num_heads_per_group; hq += 4) + { + int q0 = g * num_heads_per_group + hq; + int q1 = q0 + 1; + int q2 = q0 + 2; + int q3 = q0 + 3; + const Mat query_head0 = query_ref.channel(q0); + const Mat query_head1 = query_ref.channel(q1); + const Mat query_head2 = query_ref.channel(q2); + const Mat query_head3 = query_ref.channel(q3); + + const float* q_ptr0 = query_head0.row(m_start); + const float* q_ptr1 = query_head1.row(m_start); + const float* q_ptr2 = query_head2.row(m_start); + const float* q_ptr3 = query_head3.row(m_start); + float* s_head0 = s_ptr + hq * block_m * block_n2; + float* s_head1 = s_head0 + block_m * block_n2; + float* s_head2 = s_head1 + block_m * block_n2; + float* s_head3 = s_head2 + block_m * block_n2; + + qk_gemm_specialized_avx512_large_m<4096>(s_head0, q_ptr0, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head1, q_ptr1, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head2, q_ptr2, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head3, q_ptr3, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec0 = m_state_tile.row(hq); + float* l_vec0 = l_state_tile.row(hq); + float* o_ptr0 = o_accum_thread.row(hq * block_m); + float* m_vec1 = m_state_tile.row(hq + 1); + float* l_vec1 = l_state_tile.row(hq + 1); + float* o_ptr1 = o_accum_thread.row((hq + 1) * block_m); + float* m_vec2 = m_state_tile.row(hq + 2); + float* l_vec2 = l_state_tile.row(hq + 2); + float* o_ptr2 = o_accum_thread.row((hq + 2) * block_m); + float* m_vec3 = m_state_tile.row(hq + 3); + float* l_vec3 = 
l_state_tile.row(hq + 3); + float* o_ptr3 = o_accum_thread.row((hq + 3) * block_m); + + float scale_factors0[BLOCK_M]; + float scale_factors1[BLOCK_M]; + float scale_factors2[BLOCK_M]; + float scale_factors3[BLOCK_M]; + unsigned char changed0[BLOCK_M]; + unsigned char changed1[BLOCK_M]; + unsigned char changed2[BLOCK_M]; + unsigned char changed3[BLOCK_M]; + softmax_tile(s_head0, s_head0, m_vec0, l_vec0, scale_factors0, changed0, block_m, block_n2); + softmax_tile(s_head1, s_head1, m_vec1, l_vec1, scale_factors1, changed1, block_m, block_n2); + softmax_tile(s_head2, s_head2, m_vec2, l_vec2, scale_factors2, changed2, block_m, block_n2); + softmax_tile(s_head3, s_head3, m_vec3, l_vec3, scale_factors3, changed3, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed0[i]) + { + vec_scale(o_ptr0 + i * out_embed_dim, scale_factors0[i], out_embed_dim); + } + if (changed1[i]) + { + vec_scale(o_ptr1 + i * out_embed_dim, scale_factors1[i], out_embed_dim); + } + if (changed2[i]) + { + vec_scale(o_ptr2 + i * out_embed_dim, scale_factors2[i], out_embed_dim); + } + if (changed3[i]) + { + vec_scale(o_ptr3 + i * out_embed_dim, scale_factors3[i], out_embed_dim); + } + } + + if (block_n2 < 128) + { + if (n_start2 == 0) + pv_gemm_4heads_4096_avx512<2, 32, true>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + else + pv_gemm_4heads_4096_avx512<2, 32>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + } + else + { + if (n_start2 == 0) + pv_gemm_4heads_4096_avx512<4, 16, true>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + else + pv_gemm_4heads_4096_avx512<4, 16>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + } + } + + for (; hq + 1 < num_heads_per_group; hq += 2) + { + int q0 = g * num_heads_per_group 
+ hq; + int q1 = q0 + 1; + const Mat query_head0 = query_ref.channel(q0); + const Mat query_head1 = query_ref.channel(q1); + + const float* q_ptr0 = query_head0.row(m_start); + const float* q_ptr1 = query_head1.row(m_start); + float* s_head0 = s_ptr + hq * block_m * block_n2; + float* s_head1 = s_head0 + block_m * block_n2; + + qk_gemm_specialized_avx512_large_m<4096>(s_head0, q_ptr0, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head1, q_ptr1, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec0 = m_state_tile.row(hq); + float* l_vec0 = l_state_tile.row(hq); + float* o_ptr0 = o_accum_thread.row(hq * block_m); + float* m_vec1 = m_state_tile.row(hq + 1); + float* l_vec1 = l_state_tile.row(hq + 1); + float* o_ptr1 = o_accum_thread.row((hq + 1) * block_m); + + float scale_factors0[BLOCK_M]; + float scale_factors1[BLOCK_M]; + unsigned char changed0[BLOCK_M]; + unsigned char changed1[BLOCK_M]; + softmax_tile(s_head0, s_head0, m_vec0, l_vec0, scale_factors0, changed0, block_m, block_n2); + softmax_tile(s_head1, s_head1, m_vec1, l_vec1, scale_factors1, changed1, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed0[i]) + { + vec_scale(o_ptr0 + i * out_embed_dim, scale_factors0[i], out_embed_dim); + } + if (changed1[i]) + { + vec_scale(o_ptr1 + i * out_embed_dim, scale_factors1[i], out_embed_dim); + } + } + + if (n_start2 == 0) + pv_gemm_2heads_4096_avx512<4, 32, true>(o_ptr0, o_ptr1, s_head0, s_head1, value_head.row(n_start2), block_m, block_n2); + else + pv_gemm_2heads_4096_avx512<4, 32>(o_ptr0, o_ptr1, s_head0, s_head1, value_head.row(n_start2), block_m, block_n2); + } + + for (; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr + hq * block_m * block_n2; + + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), 
block_m, block_n2, _scale); + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + } + + continue; + } +#endif + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr + hq * block_m * block_n2; + +#if NCNN_BF16 + if (use_bf16_path) + { +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + const Mat query_bf16_head = query_bf16_ref.channel(q); + qk_gemm_bf16s_avx512bf16_qbf16(s_head, + query_bf16_head.row(m_start), + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else +#endif + { + qk_gemm_bf16s_avx512_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } +#elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_bf16s_scalar_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#endif + } + else +#endif + { + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, 
block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 32: +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + 
qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_avx(s_head, 
q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#endif + break; + } + } + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + 
n_start2; + float* sptr = s_head + i * block_n2; + decode_mask_vec(sptr, mptr, block_n2); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + +#if NCNN_BF16 + if (use_bf16_path) + { +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim, n_start2 == 0 && out_embed_dim == 4096); + } + else +#endif + { + pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } +#elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#else + pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#endif + } + else +#endif + { + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + 
pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), 
block_m, block_n2, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#endif + break; + } + } + } + } + } + } + } + + // Normalize all Q heads for this M tile and write back to top_blob + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + Mat top_blob_head = top_blob.channel(q); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + for (int i = 0; i < block_m; i++) + { + float* outptr = 
top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + { + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); + } + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, o_ptr + i * out_embed_dim + k), vinv_l)); + k = out_embed_dim; + } +#elif __AVX__ + __m256 vinv_l = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + { + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); + } +#elif __SSE2__ + __m128 vinv_l = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + { + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); + } +#endif + for (; k < out_embed_dim; k++) + { + outptr[k] = o_ptr[i * out_embed_dim + k] * inv_l; + } + } + } + } + + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } + + return 0; +} + +static int sdpa_quantize_key_value_int8_x86(const Mat& key, const Mat& value, + Mat& key_int8, Mat& key_scales, Mat& value_int8, Mat& value_scales, + int num_group, int dst_seqlen, int embed_dim, int out_embed_dim, int v_num_blocks, + bool cache_valid, int past_seqlen, bool kv_cache, + const Option& opt) +{ + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + Mat key_int8_head = key_int8.channel(g); + Mat key_scales_head = key_scales.channel(g); + int j_start = cache_valid ? 
past_seqlen : 0; + for (int j = j_start; j < dst_seqlen; j++) + { + dynamic_quantize_rowwise(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); + key_scales_head.row(j)[0] = 1.f / key_scales_head.row(j)[0]; + } + + if (kv_cache) + { + const Mat value_head = value.channel(g); + Mat value_int8_head = value_int8.channel(g); + Mat value_scales_head = value_scales.channel(g); + for (int j = j_start; j < dst_seqlen; j++) + { + dynamic_quantize_blockwise(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); + for (int vb = 0; vb < v_num_blocks; vb++) + { + value_scales_head.row(j)[vb] = 1.f / value_scales_head.row(j)[vb]; + } + } + } + } + return 0; +} + +static int sdpa_decode_int8_x86( + const Mat& query, + const Mat& key_int8, + const Mat& key_scales, + const Mat& value_int8, + const Mat& value_scales, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask, + Mat& o_accum, + Mat& s_vec, + Mat& q_int8_tile, + Mat& q_scales_tile) +{ + const int BLOCK_N = 128; + // Decode path with dedicated int8 GEMV kernels + // For GQA/MQA, group-parallel reduces KV cache contention + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_int8_head = key_int8.channel(g); + const Mat key_scales_head = key_scales.channel(g); + const Mat value_int8_head = value_int8.channel(g); + const Mat value_scales_head = value_scales.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); + float* q_scale = 
q_scales_tile.channel(get_omp_thread_num()); + dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); + q_scale[0] = 1.f / q_scale[0]; + + float* s = s_vec.channel(get_omp_thread_num()); + float* out = o_accum.channel(get_omp_thread_num()); + vec_zero(out, out_embed_dim); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + + sdpa_int8_decode_core(s, out, &m, &l, + q_int8, q_scale, + key_int8_head.row(0), + key_scales_head.row(0), + value_int8_head.row(0), + value_scales_head.row(0), + mask_ptr, + n_start, block_n, embed_dim, out_embed_dim, _scale); + } + + float* outptr = top_blob_head.row(0); + float inv_l = 1.f / l; + memcpy(outptr, out, out_embed_dim * sizeof(float)); + vec_scale(outptr, inv_l, out_embed_dim); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); + + signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); + float* q_scale = q_scales_tile.channel(get_omp_thread_num()); + dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); + q_scale[0] = 1.f / q_scale[0]; + + float* s = s_vec.channel(get_omp_thread_num()); + float* out = o_accum.channel(get_omp_thread_num()); + 
vec_zero(out, out_embed_dim); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + + sdpa_int8_decode_core(s, out, &m, &l, + q_int8, q_scale, + key_int8_head.row(0), + key_scales_head.row(0), + value_int8_head.row(0), + value_scales_head.row(0), + mask_ptr, + n_start, block_n, embed_dim, out_embed_dim, _scale); + } + + float* outptr = top_blob_head.row(0); + float inv_l = 1.f / l; + memcpy(outptr, out, out_embed_dim * sizeof(float)); + vec_scale(outptr, inv_l, out_embed_dim); + } + } + + return 0; +} + +static int sdpa_prefill_int8_x86( + const Mat& query, + const Mat& key_int8, + const Mat& key_scales, + const Mat& value_int8, + const Mat& value_scales, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int out_embed_dim, + int src_seqlen, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask, + int kv_cache, + Mat& o_accum, + Mat& s_vec, + Mat& p_vec, + Mat& q_int8_tile, + Mat& q_scales_tile, + const Mat& value) +{ + const int BLOCK_M = 64; + const int BLOCK_N = 128; +#if __AVX512F__ + const bool skip_o_accum_zero = !kv_cache && out_embed_dim == 4096; +#else + const bool skip_o_accum_zero = false; +#endif + +#if __AVX512F__ + if (!kv_cache && !attn_mask && out_embed_dim == 4096 && num_heads_per_group >= 2) + { + const int num_group = num_heads / num_heads_per_group; + const int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; + const int num_head_pairs = (num_heads_per_group + 1) / 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_group * num_m_tiles * 
num_head_pairs; task++) + { + int pair_idx = task % num_head_pairs; + int tmp = task / num_head_pairs; + int m_tile = tmp % num_m_tiles; + int g = tmp / num_m_tiles; + + int hq0 = pair_idx * 2; + int hq1 = hq0 + 1; + int q0 = g * num_heads_per_group + hq0; + int q1 = q0 + 1; + bool has_q1 = hq1 < num_heads_per_group; + + const Mat query_head0 = query.channel(q0); + const Mat query_head1 = has_q1 ? query.channel(q1) : Mat(); + const Mat key_int8_head = key_int8.channel(g); + const Mat key_scales_head = key_scales.channel(g); + const Mat value_head = value.channel(g); + Mat top_blob_head0 = top_blob.channel(q0); + Mat top_blob_head1 = has_q1 ? top_blob.channel(q1) : Mat(); + + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr0 = s_vec.row(get_omp_thread_num()); + float* s_vec_ptr1 = s_vec_ptr0 + BLOCK_M * BLOCK_N; + float* p_vec_ptr0 = p_vec.row(get_omp_thread_num()); + float* p_vec_ptr1 = p_vec_ptr0 + BLOCK_M * BLOCK_N; + Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); + Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); + + int m_start = m_tile * BLOCK_M; + int m_end = m_start + BLOCK_M < src_seqlen ? 
m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; + + for (int i = 0; i < block_m; i++) + { + dynamic_quantize_rowwise(query_head0.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; + + if (has_q1) + { + dynamic_quantize_rowwise(query_head1.row(m_start + i), q_int8_tile_head.row(BLOCK_M + i), q_scales_tile_head.row(BLOCK_M + i), embed_dim); + q_scales_tile_head.row(BLOCK_M + i)[0] = 1.f / q_scales_tile_head.row(BLOCK_M + i)[0]; + } + } - qk_gemm = 0; - qkv_gemm = 0; - qk_softmax = 0; -} + float m_vec0[BLOCK_M]; + float l_vec0[BLOCK_M]; + float m_vec1[BLOCK_M]; + float l_vec1[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_vec0[i] = -FLT_MAX; + l_vec0[i] = 0.f; + m_vec1[i] = -FLT_MAX; + l_vec1[i] = 0.f; + } -int SDPA_x86::create_pipeline(const Option& _opt) -{ - Option opt = _opt; - if (int8_scale_term) - { - opt.use_packing_layout = false; // TODO enable packing - support_bf16_storage = false; + float* o_ptr0 = o_accum_head.row(0); + float* o_ptr1 = o_accum_head.row(BLOCK_M); + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + qk_int8_gemm_tiled(s_vec_ptr0, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + + if (has_q1) + { + qk_int8_gemm_tiled(s_vec_ptr1, + q_int8_tile_head.row(BLOCK_M), + key_int8_head.row(n_start), + q_scales_tile_head.row(BLOCK_M), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } + + float scale_factors0[BLOCK_M]; + unsigned char changed0[BLOCK_M]; + softmax_tile(p_vec_ptr0, s_vec_ptr0, m_vec0, l_vec0, scale_factors0, changed0, block_m, block_n); + + float scale_factors1[BLOCK_M]; + unsigned char changed1[BLOCK_M]; + if (has_q1) + softmax_tile(p_vec_ptr1, s_vec_ptr1, m_vec1, l_vec1, scale_factors1, changed1, block_m, block_n); + + for (int i = 0; i < block_m; i++) + { + if (changed0[i]) + vec_scale(o_ptr0 + i * 4096, scale_factors0[i], 4096); + if (has_q1 && changed1[i]) + vec_scale(o_ptr1 + i * 4096, scale_factors1[i], 4096); + } + + if (has_q1) + { + if (n_start == 0) + pv_gemm_2heads_4096_avx512<16, 16, true>(o_ptr0, o_ptr1, p_vec_ptr0, p_vec_ptr1, value_head.row(n_start), block_m, block_n); + else + pv_gemm_2heads_4096_avx512<8, 16>(o_ptr0, o_ptr1, p_vec_ptr0, p_vec_ptr1, value_head.row(n_start), block_m, block_n); + } + else + { + if (n_start == 0) + pv_gemm_avx512<4, 64, true>(o_ptr0, p_vec_ptr0, value_head.row(n_start), block_m, block_n, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr0, p_vec_ptr0, value_head.row(n_start), block_m, block_n, 4096); + } + } + + for (int i = 0; i < block_m; i++) + { + const float* optr0 = o_ptr0 + i * 4096; + float* outptr0 = top_blob_head0.row(m_start + i); + __m512 vinv_l0 = _mm512_set1_ps(1.f / l_vec0[i]); + for (int k = 0; k < 4096; k += 16) + _mm512_storeu_ps(outptr0 + k, _mm512_mul_ps(_mm512_loadu_ps(optr0 + k), vinv_l0)); + + if (has_q1) + { + const float* optr1 = o_ptr1 + i * 4096; + float* outptr1 = 
top_blob_head1.row(m_start + i); + __m512 vinv_l1 = _mm512_set1_ps(1.f / l_vec1[i]); + for (int k = 0; k < 4096; k += 16) + _mm512_storeu_ps(outptr1 + k, _mm512_mul_ps(_mm512_loadu_ps(optr1 + k), vinv_l1)); + } + } + } + + return 0; } +#endif + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax); - ncnn::ParamDict pd; - pd.set(0, -1); // axis - pd.set(1, 1); - qk_softmax->load_param(pd); - qk_softmax->load_model(ModelBinFromMatArray(0)); - qk_softmax->create_pipeline(opt); - } - - // Q * K^T - if (scale != 0.f) - { - qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - - pd.set(0, scale); // alpha - pd.set(1, 1.f / scale); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, attn_mask ? 
3 : -1); // constant_broadcast_type_C (MxN) - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(13, 1); // output_elemtype = fp32 -#if NCNN_INT8 - pd.set(18, int8_scale_term); -#endif - qk_gemm->load_param(pd); - qk_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = opt; - opt1.num_threads = 1; - qk_gemm->create_pipeline(opt1); - } - - // Attn * V - { - qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(0, 1.f); // alpha - pd.set(1, 1.f); // beta - pd.set(2, 0); // transA (Attn: Seq x Seq) - pd.set(3, 0); // transB (V: Seq x Embed) => Attn * V - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(13, 1); // output_elemtype = fp32 - pd.set(14, 0); // output_transpose -#if NCNN_INT8 - pd.set(18, int8_scale_term); + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); + + Mat mask_head; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + } + + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + float* p_vec_ptr = p_vec.row(get_omp_thread_num()); + Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); + Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); + + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + { + int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; + + for (int i = 0; i < block_m; i++) + { + dynamic_quantize_rowwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + if (!skip_o_accum_zero) + vec_zero(optr, out_embed_dim); + } + + float m_vec[BLOCK_M]; + float l_vec[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + if (block_m == 1) + { + qk_int8_gemm_row(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0)[0], + key_scales_head.row(n_start), + block_n, embed_dim, _scale); + } + else + { + qk_int8_gemm_tiled(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } + + if (attn_mask) + { + for (int i = 0; i < block_m; i++) + { + float* s_row = s_vec_ptr + i * block_n; + decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); + } + } + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(p_vec_ptr, s_vec_ptr, m_vec, l_vec, scale_factors, changed, block_m, block_n); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + vec_scale(o_accum_head.row(i), scale_factors[i], out_embed_dim); + } + + if (kv_cache) + { + pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, + value_int8_head.row(n_start), + value_scales_head.row(n_start), + block_m, block_n, out_embed_dim); + } + else + { + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __SSE2__ + 
pv_gemm_sse2<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; #endif - qkv_gemm->load_param(pd); - qkv_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = opt; - opt1.num_threads = 1; - qkv_gemm->create_pipeline(opt1); + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 
64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + if (n_start == 0) + pv_gemm_avx512<4, 64, true>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + else + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 
128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#else + pv_gemm_scalar(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#endif + break; + } + } + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(optr + k), vinv_l)); + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, optr + k), vinv_l)); + } +#elif __AVX__ + __m256 vinv_l256 = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(optr + k), vinv_l256)); +#elif __SSE2__ + __m128 vinv_l128 = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(optr + 
k), vinv_l128)); +#endif + for (; k < out_embed_dim; k++) + outptr[k] = optr[k] * inv_l; + } + } } return 0; } -int SDPA_x86::destroy_pipeline(const Option& _opt) +static int sdpa_decode_bf16s_x86( + const Mat& query_ref, + const Mat& key, + const Mat& value, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask) { - Option opt = _opt; - if (int8_scale_term) + const int BLOCK_N = 128; + const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; + if (use_split_kv) { - opt.use_packing_layout = false; // TODO enable packing - } + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) + { + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; + + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + + const float* qptr = query_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float* p = partials.channel(q).row(chunk); + float* out = p + 2; + float* m_out = p; + float* l_out = p + 1; + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(out, out_embed_dim); + float m = -FLT_MAX; + float l = 0.f; + + for (int n = n_start; n < n_end; n += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n_end - n); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#endif + + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(out, scale_factor, out_embed_dim); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#else + decode_pv_gemv_bf16s_scalar_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#endif - if (qk_softmax) + m = new_m; + } + + *m_out = m; + *l_out = l; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + 
float* outptr = top_blob_head.row(0); + sdpa_decode_reduce_bf16s(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } + } + else { - qk_softmax->destroy_pipeline(opt); - delete qk_softmax; - qk_softmax = 0; + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(outptr, out_embed_dim); + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#endif + + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n_start, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(outptr, scale_factor, out_embed_dim); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#else + decode_pv_gemv_bf16s_scalar_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#endif + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale(outptr, inv_l, out_embed_dim); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + int g = q / 
num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(outptr, out_embed_dim); + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#endif + + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n_start, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(outptr, scale_factor, out_embed_dim); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(outptr, s, Vptr, n_start, block_n, 
out_embed_dim); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#else + decode_pv_gemv_bf16s_scalar_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#endif + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale(outptr, inv_l, out_embed_dim); + } + } } - if (qk_gemm) + return 0; +} + +static int sdpa_decode_x86( + const Mat& query_ref, + const Mat& key, + const Mat& value, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask) +{ + const int BLOCK_N = 128; + const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; + if (use_split_kv) { - qk_gemm->destroy_pipeline(opt); - delete qk_gemm; - qk_gemm = 0; - } + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) + { + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; + + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + + const float* qptr = query_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } - if (qkv_gemm) + float* p = partials.channel(q).row(chunk); + sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, + n_start, n_end, embed_dim, out_embed_dim, _scale); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + float* outptr = top_blob_head.row(0); + sdpa_decode_reduce(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } + } + else { - qkv_gemm->destroy_pipeline(opt); - delete qkv_gemm; - qkv_gemm = 0; + // Decode path: fused GEMV kernel for single-query attention + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } } return 0; @@ -204,146 +6249,215 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } const int num_heads_per_group = num_heads / num_group; + const float _scale = scale == 0.f ? 
1.f / sqrtf((float)embed_dim) : scale; + + const int BLOCK_M = 64; + const int BLOCK_N = 128; - Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator); - if (qk_cross.empty()) + Mat& top_blob = top_blobs[0]; + top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); + if (top_blob.empty()) return -100; - std::vector retqks(num_heads); - - // Dynamic Scale Calculation and Beta Correction - Layer* _qk_gemm = qk_gemm; - if (scale == 0.f) - { - float _scale = 1.f / sqrt(embed_dim); - - _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - - pd.set(0, _scale); // alpha - pd.set(1, 1.f / _scale); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, attn_mask ? 
3 : -1); // constant_broadcast_type_C (MxN) - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(13, 1); // output_elemtype = fp32 -#if NCNN_INT8 - pd.set(18, int8_scale_term); +#if NCNN_BF16 + bool use_bf16_path = opt.use_bf16_storage && query.elembits() == 16; + bool use_qbf16_prefill = false; +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + use_qbf16_prefill = use_bf16_path && src_seqlen > 1 && embed_dim > 512 + && (src_seqlen > 16 || num_heads_per_group > 1) + && ncnn::cpu_support_x86_avx512_bf16(); +#endif +#endif + Mat query_fp32; + Mat attn_mask_fp32; + if (use_bf16_path) + { + if (!use_qbf16_prefill) + { + cast_bfloat16_to_float32(query, query_fp32, opt); + if (query_fp32.empty()) + return -100; + } + if (attn_mask && !attn_mask_blob.empty()) + { + cast_bfloat16_to_float32(attn_mask_blob, attn_mask_fp32, opt); + if (attn_mask_fp32.empty()) + return -100; + } + } + const Mat& query_ref = use_bf16_path && !use_qbf16_prefill ? query_fp32 : query; + const Mat& attn_mask_ref = (use_bf16_path && attn_mask) ? attn_mask_fp32 : attn_mask_blob; +#else + const Mat& query_ref = query; + const Mat& attn_mask_ref = attn_mask_blob; + (void)query_ref; + (void)attn_mask_ref; #endif - _qk_gemm->load_param(pd); - _qk_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = opt; - opt1.num_threads = 1; - _qk_gemm->create_pipeline(opt1); +#if NCNN_INT8 + bool use_int8_path = int8_scale_term; + if (use_int8_path && src_seqlen == 1) + { + // Adaptive threshold: int8 decode is faster only when compute is large enough + // to amortize quantization overhead. For small problems, fall back to fp32. + if (embed_dim <= 128) + use_int8_path = (dst_seqlen >= 64); + else if (embed_dim <= 512) + use_int8_path = (dst_seqlen >= 512); + else + use_int8_path = (dst_seqlen >= 32); } - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < num_heads; i++) + if (use_int8_path) { - // 1. 
Q * K^T - std::vector qk_bottom_blobs; - qk_bottom_blobs.push_back(query.channel(i)); // Q: [Seq, Embed] - qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed] + const int qk_num_blocks = (embed_dim + 31) / 32; + (void)qk_num_blocks; + const int v_num_blocks = (out_embed_dim + 31) / 32; - if (attn_mask) + Mat key_int8(embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); + Mat key_scales(1, dst_seqlen, num_group, 4u, opt.blob_allocator); + Mat value_int8(out_embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); + Mat value_scales(v_num_blocks, dst_seqlen, num_group, 4u, opt.blob_allocator); + + if (key_int8.empty() || key_scales.empty() || value_int8.empty() || value_scales.empty()) + return -100; + + bool use_kv_cache = kv_cache && past_seqlen > 0; + bool cache_valid = false; + if (use_kv_cache + && cached_kv_seqlen == past_seqlen + && cached_num_group == num_group + && cached_embed_dim == embed_dim + && cached_out_embed_dim == out_embed_dim + && !cached_key_int8.empty() + && !cached_key_scales.empty() + && !cached_value_int8.empty() + && !cached_value_scales.empty()) { - // Ensure mask is 2D for Gemm auto-broadcast detection - Mat maskm = attn_mask_blob; - if (maskm.dims == 3) + cache_valid = true; + for (int g = 0; g < num_group; g++) { - // If c > 1, pick i-th head mask. If c == 1, pick 0-th (broadcast) - maskm = maskm.channel(maskm.c > 1 ? 
i : 0); + memcpy(key_int8.channel(g), cached_key_int8.channel(g), embed_dim * past_seqlen); + memcpy(key_scales.channel(g), cached_key_scales.channel(g), past_seqlen * sizeof(float)); + memcpy(value_int8.channel(g), cached_value_int8.channel(g), out_embed_dim * past_seqlen); + memcpy(value_scales.channel(g), cached_value_scales.channel(g), v_num_blocks * past_seqlen * sizeof(float)); } - qk_bottom_blobs.push_back(maskm); } - std::vector qk_top_blobs(1); - qk_top_blobs[0] = qk_cross.channel(i); + sdpa_quantize_key_value_int8_x86(key, value, key_int8, key_scales, value_int8, value_scales, + num_group, dst_seqlen, embed_dim, out_embed_dim, v_num_blocks, + cache_valid, past_seqlen, kv_cache, opt); - Option opt1 = opt; - opt1.num_threads = 1; - opt1.blob_allocator = qk_cross.allocator; - retqks[i] = _qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1); - } + const bool use_int8_prefill_2heads_workspace = src_seqlen > 1 && !kv_cache && !attn_mask && out_embed_dim == 4096 && num_heads_per_group >= 2; + const int int8_prefill_workspace_heads = use_int8_prefill_2heads_workspace ? 
2 : 1; - if (scale == 0.f) - { - Option opt1 = opt; - opt1.num_threads = 1; - _qk_gemm->destroy_pipeline(opt1); + Mat o_accum(out_embed_dim, BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N * BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); + Mat p_vec(BLOCK_N * BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); + Mat q_int8_tile(embed_dim, BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 1u, opt.workspace_allocator); + Mat q_scales_tile(1, BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); - delete _qk_gemm; - _qk_gemm = 0; - } + if (o_accum.empty() || s_vec.empty() || p_vec.empty() || q_int8_tile.empty() || q_scales_tile.empty()) + return -100; - for (int i = 0; i < num_heads; i++) - { - if (retqks[i] != 0) - return retqks[i]; - } + if (src_seqlen == 1) + { + int ret = sdpa_decode_int8_x86(query, key_int8, key_scales, value_int8, value_scales, + top_blob, attn_mask_ref, opt, embed_dim, num_heads, num_group, out_embed_dim, + dst_seqlen, num_heads_per_group, _scale, attn_mask, + o_accum, s_vec, q_int8_tile, q_scales_tile); + if (ret != 0) + return ret; - // 2. 
Softmax - int retqk = qk_softmax->forward_inplace(qk_cross, opt); - if (retqk != 0) - return retqk; + if (kv_cache && dst_seqlen > 0) + { + cached_key_int8 = key_int8; + cached_key_scales = key_scales; + cached_value_int8 = value_int8; + cached_value_scales = value_scales; + cached_kv_seqlen = dst_seqlen; + cached_num_group = num_group; + cached_embed_dim = embed_dim; + cached_out_embed_dim = out_embed_dim; + } - Mat value_fp32 = value; -#if NCNN_BF16 - if (opt.use_bf16_storage && value.elembits() == 16) - { - // qkv_gemm need fp32 inputs - cast_bfloat16_to_float32(value, value_fp32, opt); - if (value_fp32.empty()) - return -100; - } -#endif + return 0; + } + else + { + int ret = sdpa_prefill_int8_x86(query, key_int8, key_scales, value_int8, value_scales, + top_blob, attn_mask_ref, opt, embed_dim, num_heads, out_embed_dim, src_seqlen, + dst_seqlen, num_heads_per_group, _scale, attn_mask, kv_cache, + o_accum, s_vec, p_vec, q_int8_tile, q_scales_tile, value); + if (ret != 0) + return ret; + } - Mat& top_blob = top_blobs[0]; - top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); - if (top_blob.empty()) - return -100; + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } - // 3. 
Attn * V - std::vector retqkvs(num_heads); + if (kv_cache && dst_seqlen > 0) + { + cached_key_int8 = key_int8; + cached_key_scales = key_scales; + cached_value_int8 = value_int8; + cached_value_scales = value_scales; + cached_kv_seqlen = dst_seqlen; + cached_num_group = num_group; + cached_embed_dim = embed_dim; + cached_out_embed_dim = out_embed_dim; + } - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < num_heads; i++) + return 0; + } +#endif // NCNN_INT8 + + // FP32 optimized path using tiled GEMM + online softmax + if (src_seqlen == 1) { - std::vector qkv_bottom_blobs(2); - qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq] - qkv_bottom_blobs[1] = value_fp32.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed] +#if NCNN_BF16 + if (use_bf16_path) + { + int ret = sdpa_decode_bf16s_x86(query_ref, key, value, top_blob, attn_mask_ref, opt, + embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, attn_mask); + if (ret != 0) + return ret; - std::vector qkv_top_blobs(1); - qkv_top_blobs[0] = top_blob.channel(i); // Output + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } - Option opt1 = opt; - opt1.num_threads = 1; - retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1); - } + return 0; + } +#endif // NCNN_BF16 - for (int i = 0; i < num_heads; i++) - { - if (retqkvs[i] != 0) - return retqkvs[i]; - } + int ret = sdpa_decode_x86(query_ref, key, value, top_blob, attn_mask_ref, opt, + embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, attn_mask); + if (ret != 0) + return ret; - value_fp32.release(); + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } - if (kv_cache) - { - top_blobs[1] = key; - top_blobs[2] = value; + return 0; } - return 0; + return sdpa_forward_prefill(query_ref, query, attn_mask_ref, key, value, top_blob, + top_blobs, opt, embed_dim, src_seqlen, num_heads, + num_group, 
out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, kv_cache, attn_mask, + use_bf16_path); } } // namespace ncnn diff --git a/src/layer/x86/sdpa_x86.h b/src/layer/x86/sdpa_x86.h index 1b8f28e11f66..5e10432273e7 100644 --- a/src/layer/x86/sdpa_x86.h +++ b/src/layer/x86/sdpa_x86.h @@ -18,11 +18,21 @@ class SDPA_x86 : public SDPA virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; -public: - Layer* qk_gemm; - Layer* qkv_gemm; - - Layer* qk_softmax; +private: +#if NCNN_BF16 + mutable Mat cached_key_bf16; + mutable Mat cached_value_bf16; +#endif +#if NCNN_INT8 + mutable Mat cached_key_int8; + mutable Mat cached_key_scales; + mutable Mat cached_value_int8; + mutable Mat cached_value_scales; +#endif + mutable int cached_kv_seqlen; + mutable int cached_num_group; + mutable int cached_embed_dim; + mutable int cached_out_embed_dim; }; } // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avx2.cpp b/src/layer/x86/sdpa_x86_avx2.cpp new file mode 100644 index 000000000000..7a80ae4ed765 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avx2.cpp @@ -0,0 +1,74 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVX2__ + +void dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width) +{ + dynamic_quantize_blockwise_avx2_kernel(src, dst, scales, width); +} + +void dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width) +{ + dynamic_quantize_rowwise_avx2_kernel(src, dst, scale, width); +} + +int qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len) +{ + return qk_int8_dot_block_avx2_kernel(a, b, len); +} + +void decode_qk_dot_int8_avx2(float* s, const signed char* 
q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avx2_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avx2(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avx2(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +void decode_pv_gemv_int8_avx2(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) +{ + decode_pv_gemv_int8_avx2_kernel(out, s, V, vscales, n_start, block_n, out_d); +} + +void pv_float_int8_gemm_row_avx2(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) +{ + pv_float_int8_gemm_row_avx2_kernel(out, p_row, V, vscales, n, out_d); +} + +void pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed char* v, int len) +{ + pv_float_int8_fma_block_avx2_kernel(out, p_invscale, v, len); +} + +void pv_float_int8_gemm_tile_avx2(float* O, const float* P, const signed char* V, const float* vscales, int block_m, int block_n, int out_embed_dim) +{ + pv_float_int8_gemm_tile_avx2_kernel(O, P, V, vscales, block_m, block_n, out_embed_dim); +} + +#endif // __AVX2__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avx512bf16.cpp b/src/layer/x86/sdpa_x86_avx512bf16.cpp new file mode 100644 index 000000000000..2379c43e92bd --- /dev/null +++ b/src/layer/x86/sdpa_x86_avx512bf16.cpp @@ -0,0 +1,110 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // 
__SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_bf16s.h" + +#if __AVX512BF16__ + +void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_bf16s_avx512bf16_kernel(s, q, K, n_start, block_n, d, scale); +} + +void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + decode_pv_gemv_bf16s_avx512bf16_kernel(out, s, V, n_start, block_n, out_d); +} + +// Explicit instantiations for common embed_dim values +// These allow GCC to fully unroll the inner d-loops at compile time, +// matching the performance of the fp32 template-specialized path. +template void qk_gemm_bf16s_avx512bf16_kernel_t<64>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<128>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<256>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int, float); + +template void pv_gemm_bf16s_avx512bf16_kernel_t<64>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<128>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<256>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const 
unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int, bool); + +void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) +{ + switch (d) + { + case 64: + qk_gemm_bf16s_avx512bf16_kernel_t<64>(S, Q, K, m, n, scale); + break; + case 128: + qk_gemm_bf16s_avx512bf16_kernel_t<128>(S, Q, K, m, n, scale); + break; + case 256: + qk_gemm_bf16s_avx512bf16_kernel_t<256>(S, Q, K, m, n, scale); + break; + case 512: + qk_gemm_bf16s_avx512bf16_kernel_t<512>(S, Q, K, m, n, scale); + break; + case 1024: + qk_gemm_bf16s_avx512bf16_kernel_t<1024>(S, Q, K, m, n, scale); + break; + case 4096: + qk_gemm_bf16s_avx512bf16_kernel_t<4096>(S, Q, K, m, n, scale); + break; + default: + qk_gemm_bf16s_avx512bf16_kernel(S, Q, K, m, n, d, scale); + break; + } +} + +void qk_gemm_bf16s_avx512bf16_qbf16(float* S, const unsigned short* Q, const unsigned short* K, int m, int n, int d, float scale) +{ + qk_gemm_bf16s_avx512bf16_qbf16_kernel(S, Q, K, m, n, d, scale); +} + +void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d, bool init_zero) +{ + switch (d) + { + case 64: + pv_gemm_bf16s_avx512bf16_kernel_t<64>(O, P, V, m, n, init_zero); + break; + case 128: + pv_gemm_bf16s_avx512bf16_kernel_t<128>(O, P, V, m, n, init_zero); + break; + case 256: + pv_gemm_bf16s_avx512bf16_kernel_t<256>(O, P, V, m, n, init_zero); + break; + case 512: + pv_gemm_bf16s_avx512bf16_kernel_t<512>(O, P, V, m, n, init_zero); + break; + case 1024: + pv_gemm_bf16s_avx512bf16_kernel_t<1024>(O, P, V, m, n, init_zero); + break; + case 4096: + pv_gemm_bf16s_avx512bf16_kernel_t<4096>(O, P, V, m, n, init_zero); + break; + default: + pv_gemm_bf16s_avx512bf16_kernel(O, P, V, m, n, d); + break; + } +} + +#endif // __AVX512BF16__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avx512vnni.cpp b/src/layer/x86/sdpa_x86_avx512vnni.cpp new 
file mode 100644 index 000000000000..2719d668ed73 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avx512vnni.cpp @@ -0,0 +1,42 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVX512F__ && __AVX512VNNI__ + +void decode_qk_dot_int8_avx512vnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avx512_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avx512vnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + qk_int8_gemm_row_avx512vnni_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avx512vnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + qk_int8_gemm_tiled_avx512vnni_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +#endif // __AVX512F__ && __AVX512VNNI__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avxvnni.cpp b/src/layer/x86/sdpa_x86_avxvnni.cpp new file mode 100644 index 000000000000..296cbc89a7e0 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avxvnni.cpp @@ -0,0 +1,41 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVXVNNI__ + +void 
decode_qk_dot_int8_avxvnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avxvnni_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avxvnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + // fallback to avx2 kernel for row-wise gemm + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avxvnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + // fallback to avx2 kernel for tiled gemm + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +#endif // __AVXVNNI__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avxvnniint8.cpp b/src/layer/x86/sdpa_x86_avxvnniint8.cpp new file mode 100644 index 000000000000..f858dceb08c5 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avxvnniint8.cpp @@ -0,0 +1,39 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVXVNNIINT8__ + +void decode_qk_dot_int8_avxvnniint8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avxvnni_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avxvnniint8(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + 
+void qk_int8_gemm_tiled_avxvnniint8(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +#endif // __AVXVNNIINT8__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h new file mode 100644 index 000000000000..52721356c8bb --- /dev/null +++ b/src/layer/x86/sdpa_x86_bf16s.h @@ -0,0 +1,2462 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef SDPA_X86_BF16S_H +#define SDPA_X86_BF16S_H + +#include + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale); +void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d); +void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale); +void qk_gemm_bf16s_avx512bf16_qbf16(float* S, const unsigned short* Q, const unsigned short* K, int m, int n, int d, float scale); +void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d, bool init_zero = false); +#endif + +// --------------------------------------------------------------------------- +// decode_qk_dot_bf16s : Q(fp32) dot K(bf16) -> S(fp32) +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static inline void decode_qk_dot_bf16s_avx512_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + if (d >= 256) + { 
+ for (; j + 7 < block_n; j += 8) + { + const unsigned short* k0 = K + (n_start + j + 0) * d; + const unsigned short* k1 = K + (n_start + j + 1) * d; + const unsigned short* k2 = K + (n_start + j + 2) * d; + const unsigned short* k3 = K + (n_start + j + 3) * d; + const unsigned short* k4 = K + (n_start + j + 4) * d; + const unsigned short* k5 = K + (n_start + j + 5) * d; + const unsigned short* k6 = K + (n_start + j + 6) * d; + const unsigned short* k7 = K + (n_start + j + 7) * d; + + if (j + 15 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 8) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 9) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 10) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 11) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 12) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 13) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 14) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 15) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + __m512 acc4 = _mm512_setzero_ps(); + __m512 acc5 = _mm512_setzero_ps(); + __m512 acc6 = _mm512_setzero_ps(); + __m512 acc7 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k2 + k))), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k3 + k))), acc3); + acc4 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k4 + k))), acc4); + 
acc5 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k5 + k))), acc5); + acc6 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k6 + k))), acc6); + acc7 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k7 + k))), acc7); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k2 + k)), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k3 + k)), acc3); + acc4 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k4 + k)), acc4); + acc5 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k5 + k)), acc5); + acc6 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k6 + k)), acc6); + acc7 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k7 + k)), acc7); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + s[j + 4] = _mm512_comp_reduce_add_ps(acc4) * scale; + s[j + 5] = _mm512_comp_reduce_add_ps(acc5) * scale; + s[j + 6] = _mm512_comp_reduce_add_ps(acc6) * scale; + s[j + 7] = _mm512_comp_reduce_add_ps(acc7) * scale; + } + } + + for (; j + 3 < block_n; j += 4) + { + const unsigned short* k0 = K + (n_start + j + 0) * d; + const unsigned short* k1 = K + (n_start + j + 1) * d; + const unsigned short* k2 = K + (n_start + j + 2) * d; + const unsigned short* k3 = K + (n_start + j + 3) * d; + + if 
(j + 7 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 5) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 6) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 7) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k2 + k))), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k3 + k))), acc3); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k2 + k)), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k3 + k)), acc3); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + __m512 acc = _mm512_setzero_ps(); + int k = 0; + const unsigned short* kptr = K + (n_start + j) * d; + 
for (; k + 15 < d; k += 16) + acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))), acc); + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)), acc); + } + s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + } +} +#endif // __AVX512F__ + +#if __AVX__ +static inline void decode_qk_dot_bf16s_avx_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const unsigned short* k0 = K + (n_start + j + 0) * d; + const unsigned short* k1 = K + (n_start + j + 1) * d; + + if (j + 5 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 3) * d), _MM_HINT_T1); + } + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qv = _mm256_loadu_ps(q + k); + acc0 = _mm256_comp_fmadd_ps(qv, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k0 + k))), acc0); + acc1 = _mm256_comp_fmadd_ps(qv, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k1 + k))), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + sum0 += q[k] * bfloat16_to_float32(k0[k]); + sum1 += q[k] * bfloat16_to_float32(k1[k]); + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + + for (; j < block_n; j++) + { + if (j + 2 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + + const unsigned short* kptr = K + (n_start + j) * d; + __m256 acc = _mm256_setzero_ps(); + int k = 0; + for (; k + 7 < d; k += 8) + acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), bfloat2float_avx(_mm_loadu_si128((const 
__m128i*)(kptr + k))), acc); + float sum = _mm256_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * bfloat16_to_float32(kptr[k]); + s[j] = sum * scale; + } +} +#endif // __AVX__ + +#if __SSE2__ +static inline void decode_qk_dot_bf16s_sse2_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + for (int j = 0; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + __m128 acc = _mm_setzero_ps(); + int k = 0; + const unsigned short* kptr = K + (n_start + j) * d; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(q + k), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr + k))), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * bfloat16_to_float32(kptr[k]); + s[j] = sum * scale; + } +} +#endif // __SSE2__ + +static inline void decode_qk_dot_bf16s_scalar_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + for (int j = 0; j < block_n; j++) + { + float sum = 0.f; + const unsigned short* kptr = K + (n_start + j) * d; + for (int k = 0; k < d; k++) + sum += q[k] * bfloat16_to_float32(kptr[k]); + s[j] = sum * scale; + } +} + +// --------------------------------------------------------------------------- +// decode_pv_gemv_bf16s : S(fp32) gemv V(bf16) -> out(fp32) +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static inline void decode_pv_gemv_bf16s_avx512_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + if (j + 6 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 6) * out_d), _MM_HINT_T1); + + __m512 pvec0 = _mm512_set1_ps(s[j]); + __m512 pvec1 = _mm512_set1_ps(s[j + 1]); + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = 
_mm512_loadu_ps(out + k + 16); + __m512 v00 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + __m512 v01 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k + 16))); + __m512 v10 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j + 1) * out_d + k))); + __m512 v11 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j + 1) * out_d + k + 16))); + oval0 = _mm512_fmadd_ps(pvec0, v00, oval0); + oval1 = _mm512_fmadd_ps(pvec0, v01, oval1); + oval0 = _mm512_fmadd_ps(pvec1, v10, oval0); + oval1 = _mm512_fmadd_ps(pvec1, v11, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 v0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + __m512 v1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j + 1) * out_d + k))); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_storeu_ps(out + k, oval); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 v0 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, V + (n_start + j) * out_d + k)); + __m512 v1 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, V + (n_start + j + 1) * out_d + k)); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_mask_storeu_ps(out + k, mask_d, oval); + } + } + for (; j < block_n; j++) + { + __m512 pvec512 = _mm512_set1_ps(s[j]); + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * 
out_d + k))); + __m512 v1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k + 16))); + oval0 = _mm512_fmadd_ps(pvec512, v0, oval0); + oval1 = _mm512_fmadd_ps(pvec512, v1, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 vval = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec512, vval, oval)); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 vval = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, V + (n_start + j) * out_d + k)); + _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec512, vval, oval)); + } + } +} +#endif // __AVX512F__ + +#if __AVX__ +static inline void decode_pv_gemv_bf16s_avx_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + int j = 0; + for (; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + + int k = 0; + __m256 pvec256 = _mm256_set1_ps(s[j]); + for (; k + 7 < out_d; k += 8) + { + __m256 oval = _mm256_loadu_ps(out + k); + __m256 vval = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + (n_start + j) * out_d + k))); + _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec256, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * bfloat16_to_float32(V[(n_start + j) * out_d + k]); + } +} +#endif // __AVX__ + +#if __SSE2__ +static inline void decode_pv_gemv_bf16s_sse2_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + 
+ __m128 pvec128 = _mm_set1_ps(s[j]); + int k = 0; + for (; k + 3 < out_d; k += 4) + { + __m128 oval = _mm_loadu_ps(out + k); + __m128 vval = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + (n_start + j) * out_d + k))); + _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec128, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * bfloat16_to_float32(V[(n_start + j) * out_d + k]); + } +} +#endif // __SSE2__ + +static inline void decode_pv_gemv_bf16s_scalar_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + float p = s[j]; + const unsigned short* vptr = V + (n_start + j) * out_d; + for (int k = 0; k < out_d; k++) + out[k] += p * bfloat16_to_float32(vptr[k]); + } +} + +// --------------------------------------------------------------------------- +// sdpa_decode_reduce_bf16s : reduce partial results from split-kv chunks +// --------------------------------------------------------------------------- + +static inline void sdpa_decode_reduce_bf16s( + float* out, int out_d, + const float* partials, int num_chunks, int partial_stride) +{ + float M_final = -FLT_MAX; + float S_final = 0.f; + // vec_zero + { +#if __AVX512F__ + __m512 zero512 = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, zero512); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, zero512); + } +#else + int i = 0; +#if __AVX__ + __m256 zero256 = _mm256_setzero_ps(); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, zero256); +#endif +#if __SSE2__ + __m128 zero128 = _mm_setzero_ps(); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, zero128); +#endif + for (; i < out_d; i++) + out[i] = 0.f; +#endif + } + + for (int c = 0; c < num_chunks; c++) + { + const float* p = partials + c * partial_stride; + float M_chunk = p[0]; + float S_chunk = p[1]; + if (S_chunk == 0.f) continue; + + 
const float* VKQ_chunk = p + 2; + + float M_new = std::max(M_final, M_chunk); + float scale_final = expf(M_final - M_new); + float scale_chunk = expf(M_chunk - M_new); + + for (int k = 0; k < out_d; k++) + { + out[k] = out[k] * scale_final + VKQ_chunk[k] * scale_chunk; + } + + S_final = S_final * scale_final + S_chunk * scale_chunk; + M_final = M_new; + } + + if (S_final != 0.f) + { + float inv_s = 1.f / S_final; + // vec_scale + { +#if __AVX512F__ + __m512 vscale512 = _mm512_set1_ps(inv_s); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); + } +#else + int i = 0; +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(inv_s); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); +#endif +#if __SSE2__ + __m128 vscale128 = _mm_set1_ps(inv_s); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); +#endif + for (; i < out_d; i++) + out[i] *= inv_s; +#endif + } + } +} + +// --------------------------------------------------------------------------- +// qk_gemm_bf16s : Q(fp32) x K^T(bf16) -> S(fp32) [prefill] +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static void qk_gemm_bf16s_avx512_kernel(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))); + __m512 kv1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kv0 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)); + __m512 kv1 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kvec = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kvec = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i + 
4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))); + __m512 kv1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kv0 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)); + __m512 kv1 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kvec = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 
kvec = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 3 < n; j += 4) + { + const float* qptr = Q + i * d; + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + const unsigned short* k2 = K + (j + 2) * d; + const unsigned short* k3 = K + (j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))), acc0); + acc1 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))), acc1); + acc2 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k2 + k))), acc2); + acc3 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k3 + k))), acc3); + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 qvec = _mm512_maskz_loadu_ps(mask, qptr + k); + acc0 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)), acc0); + acc1 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)), acc1); + acc2 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k2 + k)), acc2); + acc3 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k3 + k)), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = 
_mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const unsigned short* kptr = K + j * d; + int k = 0; + __m512 vacc = _mm512_setzero_ps(); + for (; k + 15 < d; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))), vacc); + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, qptr + k), bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)), vacc); + } + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} +#endif // __AVX512F__ + +#if __AVX__ +static void qk_gemm_bf16s_avx_kernel(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m256 acc[6][2]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kv0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k0 + k))); + __m256 kv1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k1 + k))); + + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum0 = _mm256_reduce_add_ps(acc[mi][0]); + float sum1 = _mm256_reduce_add_ps(acc[mi][1]); + for (; k < d; k++) + { + float qv = Q[(i + mi) * d + k]; + sum0 += qv * bfloat16_to_float32(k0[k]); + sum1 += qv * 
bfloat16_to_float32(k1[k]); + } + S[(i + mi) * n + j + 0] = sum0 * scale; + S[(i + mi) * n + j + 1] = sum1 * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m256 acc[6]; + for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kvec = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(kptr + k))); + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum = _mm256_reduce_add_ps(acc[mi]); + for (; k < d; k++) + sum += Q[(i + mi) * d + k] * bfloat16_to_float32(kptr[k]); + S[(i + mi) * n + j] = sum * scale; + } + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * d; + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k0 + k))), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k1 + k))), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * bfloat16_to_float32(k0[k]); + sum1 += qv * bfloat16_to_float32(k1[k]); + } + S[i * n + j + 0] = sum0 * scale; + S[i * n + j + 1] = sum1 * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const unsigned short* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m256 vacc = _mm256_setzero_ps(); + for (; k + 7 < d; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), bfloat2float_avx(_mm_loadu_si128((const __m128i*)(kptr + k))), vacc); + sum = 
_mm256_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * bfloat16_to_float32(kptr[k]); + S[i * n + j] = sum * scale; + } + } +} +#endif // __AVX__ + +#if __SSE2__ +static void qk_gemm_bf16s_sse2_kernel(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + for (int i = 0; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * d; + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + int k = 0; + for (; k + 3 < d; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(k0 + k))), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(k1 + k))), acc1); + } + float sum0 = _mm_reduce_add_ps(acc0); + float sum1 = _mm_reduce_add_ps(acc1); + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * bfloat16_to_float32(k0[k]); + sum1 += qv * bfloat16_to_float32(k1[k]); + } + S[i * n + j + 0] = sum0 * scale; + S[i * n + j + 1] = sum1 * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const unsigned short* kptr = K + j * d; + __m128 acc = _mm_setzero_ps(); + int k = 0; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr + k))), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += qptr[k] * bfloat16_to_float32(kptr[k]); + S[i * n + j] = sum * scale; + } + } +} +#endif // __SSE2__ + +static void qk_gemm_bf16s_scalar_kernel(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * d; + for (int j = 0; j < n; j++) + { + const unsigned short* kptr = K + j * d; + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += qptr[k] * bfloat16_to_float32(kptr[k]); + 
S[i * n + j] = sum * scale; + } + } +} + +// --------------------------------------------------------------------------- +// pv_gemm_bf16s : P(fp32) x V(bf16) -> O(fp32) [prefill] +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static void pv_gemm_bf16s_avx512_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + int dd = 0; + for (; dd + 127 < d; dd += 128) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + const float* pptr[4]; + for (int mi = 0; mi < 4; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m512 acc[4][8]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + 0 * 16); + acc[mi][1] = _mm512_loadu_ps(op[mi] + 1 * 16); + acc[mi][2] = _mm512_loadu_ps(op[mi] + 2 * 16); + acc[mi][3] = _mm512_loadu_ps(op[mi] + 3 * 16); + acc[mi][4] = _mm512_loadu_ps(op[mi] + 4 * 16); + acc[mi][5] = _mm512_loadu_ps(op[mi] + 5 * 16); + acc[mi][6] = _mm512_loadu_ps(op[mi] + 6 * 16); + acc[mi][7] = _mm512_loadu_ps(op[mi] + 7 * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 v0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 0 * 16))); + __m512 v1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 1 * 16))); + __m512 v2 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 2 * 16))); + __m512 v3 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 3 * 16))); + __m512 v4 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 4 * 16))); + __m512 v5 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 5 * 16))); + __m512 v6 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 6 * 16))); + __m512 v7 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 7 * 16))); + + for (int mi = 0; mi < 4; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + 
acc[mi][0] = _mm512_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(pvec, v3, acc[mi][3]); + acc[mi][4] = _mm512_fmadd_ps(pvec, v4, acc[mi][4]); + acc[mi][5] = _mm512_fmadd_ps(pvec, v5, acc[mi][5]); + acc[mi][6] = _mm512_fmadd_ps(pvec, v6, acc[mi][6]); + acc[mi][7] = _mm512_fmadd_ps(pvec, v7, acc[mi][7]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + _mm512_storeu_ps(op[mi] + 0 * 16, acc[mi][0]); + _mm512_storeu_ps(op[mi] + 1 * 16, acc[mi][1]); + _mm512_storeu_ps(op[mi] + 2 * 16, acc[mi][2]); + _mm512_storeu_ps(op[mi] + 3 * 16, acc[mi][3]); + _mm512_storeu_ps(op[mi] + 4 * 16, acc[mi][4]); + _mm512_storeu_ps(op[mi] + 5 * 16, acc[mi][5]); + _mm512_storeu_ps(op[mi] + 6 * 16, acc[mi][6]); + _mm512_storeu_ps(op[mi] + 7 * 16, acc[mi][7]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc0 = _mm512_loadu_ps(optr + 0 * 16); + __m512 acc1 = _mm512_loadu_ps(optr + 1 * 16); + __m512 acc2 = _mm512_loadu_ps(optr + 2 * 16); + __m512 acc3 = _mm512_loadu_ps(optr + 3 * 16); + __m512 acc4 = _mm512_loadu_ps(optr + 4 * 16); + __m512 acc5 = _mm512_loadu_ps(optr + 5 * 16); + __m512 acc6 = _mm512_loadu_ps(optr + 6 * 16); + __m512 acc7 = _mm512_loadu_ps(optr + 7 * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + acc0 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 0 * 16))), acc0); + acc1 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 1 * 16))), acc1); + acc2 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 2 * 16))), acc2); + acc3 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 3 * 16))), acc3); + acc4 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const 
__m256i*)(V + j * d + dd + 4 * 16))), acc4); + acc5 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 5 * 16))), acc5); + acc6 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 6 * 16))), acc6); + acc7 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 7 * 16))), acc7); + } + + _mm512_storeu_ps(optr + 0 * 16, acc0); + _mm512_storeu_ps(optr + 1 * 16, acc1); + _mm512_storeu_ps(optr + 2 * 16, acc2); + _mm512_storeu_ps(optr + 3 * 16, acc3); + _mm512_storeu_ps(optr + 4 * 16, acc4); + _mm512_storeu_ps(optr + 5 * 16, acc5); + _mm512_storeu_ps(optr + 6 * 16, acc6); + _mm512_storeu_ps(optr + 7 * 16, acc7); + } + } + + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + const float* pptr[4]; + for (int mi = 0; mi < 4; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m512 vvec = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd))); + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + for (int mi = 0; mi < 4; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc = _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd))), acc); + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += pptr[j] * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } +} +#endif // __AVX512F__ + 
+#if __AVX__ +static void pv_gemm_bf16s_avx_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + int dd = 0; + for (; dd + 31 < d; dd += 32) + { + int i = 0; + for (; i + 2 <= m; i += 2) + { + float* op[2]; + const float* pptr[2]; + for (int mi = 0; mi < 2; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m256 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_loadu_ps(op[mi] + 0 * 8); + acc[mi][1] = _mm256_loadu_ps(op[mi] + 1 * 8); + acc[mi][2] = _mm256_loadu_ps(op[mi] + 2 * 8); + acc[mi][3] = _mm256_loadu_ps(op[mi] + 3 * 8); + } + + for (int j = 0; j < n; j++) + { + __m256 v0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 0 * 8))); + __m256 v1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 1 * 8))); + __m256 v2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 2 * 8))); + __m256 v3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 3 * 8))); + + for (int mi = 0; mi < 2; mi++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + acc[mi][0] = _mm256_comp_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(pvec, v3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + _mm256_storeu_ps(op[mi] + 0 * 8, acc[mi][0]); + _mm256_storeu_ps(op[mi] + 1 * 8, acc[mi][1]); + _mm256_storeu_ps(op[mi] + 2 * 8, acc[mi][2]); + _mm256_storeu_ps(op[mi] + 3 * 8, acc[mi][3]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m256 acc0 = _mm256_loadu_ps(optr + 0 * 8); + __m256 acc1 = _mm256_loadu_ps(optr + 1 * 8); + __m256 acc2 = _mm256_loadu_ps(optr + 2 * 8); + __m256 acc3 = _mm256_loadu_ps(optr + 3 * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + acc0 = _mm256_comp_fmadd_ps(pvec, 
bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 0 * 8))), acc0); + acc1 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 1 * 8))), acc1); + acc2 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 2 * 8))), acc2); + acc3 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 3 * 8))), acc3); + } + + _mm256_storeu_ps(optr + 0 * 8, acc0); + _mm256_storeu_ps(optr + 1 * 8, acc1); + _mm256_storeu_ps(optr + 2 * 8, acc2); + _mm256_storeu_ps(optr + 3 * 8, acc3); + } + } + + for (; dd + 7 < d; dd += 8) + { + int i = 0; + for (; i + 2 <= m; i += 2) + { + float* op[2]; + const float* pptr[2]; + for (int mi = 0; mi < 2; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + __m256 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m256 vvec = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd))); + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + for (int mi = 0; mi < 2; mi++) + _mm256_storeu_ps(op[mi], acc[mi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m256 acc = _mm256_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd))), acc); + _mm256_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += pptr[j] * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } +} +#endif // __AVX__ + +#if __SSE2__ +static void pv_gemm_bf16s_sse2_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + int dd = 0; + 
for (; dd + 15 < d; dd += 16) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m128 acc0 = _mm_loadu_ps(optr + 0 * 4); + __m128 acc1 = _mm_loadu_ps(optr + 1 * 4); + __m128 acc2 = _mm_loadu_ps(optr + 2 * 4); + __m128 acc3 = _mm_loadu_ps(optr + 3 * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[j]); + acc0 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 0 * 4))), acc0); + acc1 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 1 * 4))), acc1); + acc2 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 2 * 4))), acc2); + acc3 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 3 * 4))), acc3); + } + + _mm_storeu_ps(optr + 0 * 4, acc0); + _mm_storeu_ps(optr + 1 * 4, acc1); + _mm_storeu_ps(optr + 2 * 4, acc2); + _mm_storeu_ps(optr + 3 * 4, acc3); + } + } + + for (; dd + 3 < d; dd += 4) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m128 acc = _mm_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd))), acc); + _mm_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += pptr[j] * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } +} +#endif // __SSE2__ + +static void pv_gemm_bf16s_scalar_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + for (int i = 0; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + { + float p = pptr[j]; + const unsigned short* vptr = V + j * d; + for (int k = 0; k < d; k++) + 
optr[k] += p * bfloat16_to_float32(vptr[k]); + } + } +} + +#if __AVX512F__ && __AVX512BF16__ + +static void decode_qk_dot_bf16s_avx512bf16_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + int n_end = n_start + block_n; + int j = n_start; + for (; j + 1 < n_end; j += 2) + { + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q + k)); + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(K + j * d + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(K + (j + 1) * d + k)); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)qv, (__m512bh)kv0); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)qv, (__m512bh)kv1); + } + if (k + 15 < d) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q + k)), 0); + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(K + j * d + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(K + (j + 1) * d + k)), 0); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)qv, (__m512bh)kv0); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)qv, (__m512bh)kv1); + k += 16; + } + float tail_sum0 = 0, tail_sum1 = 0; + for (; k < d; k++) + { + float qv = q[k]; + tail_sum0 += qv * bfloat16_to_float32(K[j * d + k]); + tail_sum1 += qv * bfloat16_to_float32(K[(j + 1) * d + k]); + } + s[j] = (_mm512_comp_reduce_add_ps(acc0) + tail_sum0) * scale; + s[j + 1] = (_mm512_comp_reduce_add_ps(acc1) + tail_sum1) * scale; + } + for (; j < n_end; j++) + { + __m512 acc = _mm512_setzero_ps(); + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q + k)); + __m512i kv = _mm512_loadu_si512((const __m512i*)(K + j * d + k)); + acc = _mm512_dpbf16_ps(acc, (__m512bh)qv, (__m512bh)kv); + } + if (k + 15 < d) + { + __m512i qv = 
_mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q + k)), 0); + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(K + j * d + k)), 0); + acc = _mm512_dpbf16_ps(acc, (__m512bh)qv, (__m512bh)kv); + k += 16; + } + float tail_sum = 0; + for (; k < d; k++) + tail_sum += q[k] * bfloat16_to_float32(K[j * d + k]); + s[j] = (_mm512_comp_reduce_add_ps(acc) + tail_sum) * scale; + } +} + +static void decode_pv_gemv_bf16s_avx512bf16_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + int n_end = n_start + block_n; + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + for (int j = n_start; j < n_end; j++) + { + __m512 sj = _mm512_set1_ps(s[j]); + __m512i v0 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(V + j * out_d + k))); + __m512i v1 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(V + j * out_d + k + 16))); + oval0 = _mm512_fmadd_ps(sj, _mm512_castsi512_ps(_mm512_slli_epi32(v0, 16)), oval0); + oval1 = _mm512_fmadd_ps(sj, _mm512_castsi512_ps(_mm512_slli_epi32(v1, 16)), oval1); + } + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + for (int j = n_start; j < n_end; j++) + { + __m512 sj = _mm512_set1_ps(s[j]); + __m512i v0 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(V + j * out_d + k))); + oval0 = _mm512_fmadd_ps(sj, _mm512_castsi512_ps(_mm512_slli_epi32(v0, 16)), oval0); + } + _mm512_storeu_ps(out + k, oval0); + k += 16; + } + for (; k < out_d; k++) + { + for (int j = n_start; j < n_end; j++) + out[k] += s[j] * bfloat16_to_float32(V[j * out_d + k]); + } +} + +static void qk_gemm_bf16s_avx512bf16_kernel(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + unsigned short* q_bf16 = (unsigned 
short*)_mm_malloc(m * d * sizeof(unsigned short), 64); + + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * d; + unsigned short* qdst = q_bf16 + i * d; + int k = 0; + for (; k + 15 < d; k += 16) + { + __m256i t = float2bfloat_avx512(_mm512_loadu_ps(qptr + k)); + _mm256_storeu_si256((__m256i*)(qdst + k), t); + } + for (; k < d; k++) + qdst[k] = float32_to_bfloat16(qptr[k]); + } + + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[8][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + 
+ for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[8] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d 
+ k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[4][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi] 
= _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[4] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + if (i < m) + { + qk_gemm_bf16s_avx512_kernel(S + i * n, Q + i * d, K, m - i, n, d, scale); + } + + _mm_free(q_bf16); +} + +static void qk_gemm_bf16s_avx512bf16_qbf16_kernel(float* S, const unsigned short* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + 
k += 16; + } + float tail_sum[8][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[8] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for 
(; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[4][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = 
_mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[4] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + if (i < m) + { + for (int ii = i; ii < m; ii++) + for (int j = 0; j < n; j++) + { + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += bfloat16_to_float32(Q[ii * d + k]) * bfloat16_to_float32(K[j * d + k]); + S[ii * n + j] = sum * scale; + } + } +} + +static void pv_gemm_bf16s_avx512bf16_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + unsigned short* p_bf16 = (unsigned short*)_mm_malloc(m * n * sizeof(unsigned short), 64); + for (int i = 0; i < m; i++) + { + const float* pptr = P + i * n; + unsigned short* pdst = p_bf16 + i * n; + int j = 0; + for (; j + 15 < n; j += 16) + { + __m512 pvec = _mm512_loadu_ps(pptr + j); + __m256i pbf16 = (__m256i)_mm512_cvtneps_pbh(pvec); + _mm256_storeu_si256((__m256i*)(pdst + j), pbf16); + } + for (; j < n; j++) + pdst[j] = float32_to_bfloat16(pptr[j]); + } + + int dd = 0; + for (; dd + 31 < d; dd += 32) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * d + dd; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + 0); + acc[mi][1] = _mm512_loadu_ps(op[mi] + 16); + } + + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * d + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + 
__m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)p_ext, (__m512bh)v0_ext); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)p_ext, (__m512bh)v1_ext); + } + } + + for (int mi = 0; mi < 4; mi++) + { + _mm512_storeu_ps(op[mi] + 0, acc[mi][0]); + _mm512_storeu_ps(op[mi] + 16, acc[mi][1]); + } + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * d + dd; + float* op1 = O + (i + 1) * d + dd; + __m512 acc00 = _mm512_loadu_ps(op0 + 0); + __m512 acc01 = _mm512_loadu_ps(op0 + 16); + __m512 acc10 = _mm512_loadu_ps(op1 + 0); + __m512 acc11 = _mm512_loadu_ps(op1 + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * d + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc00 = _mm512_dpbf16_ps(acc00, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc01 = _mm512_dpbf16_ps(acc01, (__m512bh)p_ext0, (__m512bh)v1_ext); + acc10 = _mm512_dpbf16_ps(acc10, (__m512bh)p_ext1, (__m512bh)v0_ext); + acc11 = _mm512_dpbf16_ps(acc11, (__m512bh)p_ext1, (__m512bh)v1_ext); + } + _mm512_storeu_ps(op0 + 0, acc00); + _mm512_storeu_ps(op1 + 0, acc10); + _mm512_storeu_ps(op0 + 16, acc01); + _mm512_storeu_ps(op1 + 16, acc11); + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + __m512 acc0 = _mm512_loadu_ps(optr + 0); + __m512 acc1 = _mm512_loadu_ps(optr + 16); + for (int j = 0; j < n; j++) + { + 
const unsigned short* vptr = V + j * d + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext, (__m512bh)v1_ext); + } + _mm512_storeu_ps(optr + 0, acc0); + _mm512_storeu_ps(optr + 16, acc1); + } + } + + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * d + dd; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * d + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)p_ext, (__m512bh)v0_ext); + } + } + for (int mi = 0; mi < 4; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * d + dd; + float* op1 = O + (i + 1) * d + dd; + __m512 acc0 = _mm512_loadu_ps(op0); + __m512 acc1 = _mm512_loadu_ps(op1); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * d + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext1, (__m512bh)v0_ext); + } + _mm512_storeu_ps(op0, acc0); + 
_mm512_storeu_ps(op1, acc1); + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + __m512 acc = _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * d + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc = _mm512_dpbf16_ps(acc, (__m512bh)p_ext, (__m512bh)v0_ext); + } + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += bfloat16_to_float32(p_bf16[i * n + j]) * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } + + _mm_free(p_bf16); +} + +template +static void qk_gemm_bf16s_avx512bf16_kernel_t(float* S, const float* Q, const unsigned short* K, + int m, int n, float scale) +{ + unsigned short* q_bf16 = (unsigned short*)_mm_malloc(m * D * sizeof(unsigned short), 64); + + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * D; + unsigned short* qdst = q_bf16 + i * D; + int k = 0; + for (; k + 15 < D; k += 16) + { + __m256i t = float2bfloat_avx512(_mm512_loadu_ps(qptr + k)); + _mm256_storeu_si256((__m256i*)(qdst + k), t); + } + for (; k < D; k++) + qdst[k] = float32_to_bfloat16(qptr[k]); + } + + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * D; + const unsigned short* k1 = K + (j + 1) * D; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], 
(__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < D) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[8][2] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * D; + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < D) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, 
(__m512bh)kv); + } + k += 16; + } + float tail_sum[8] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * D; + const unsigned short* k1 = K + (j + 1) * D; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < D) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[4][2] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 4; mi++) + { 
+ S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * D; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < D) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[4] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + if (i < m) + { + // scalar tail fallback not available in template + for (int ii = i; ii < m; ii++) + for (int j = 0; j < n; j++) + { + float sum = 0; + for (int k = 0; k < D; k++) + sum += Q[ii * D + k] * bfloat16_to_float32(K[j * D + k]); + S[ii * n + j] = sum * scale; + } + } + + _mm_free(q_bf16); +} + +template +static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n, bool init_zero) +{ + unsigned short* p_bf16 = (unsigned short*)_mm_malloc(m * n * sizeof(unsigned short), 64); + for (int i = 0; i < m; i++) + { + const float* pptr = P + i * n; + unsigned short* 
pdst = p_bf16 + i * n; + int j = 0; + for (; j + 15 < n; j += 16) + { + __m512 pvec = _mm512_loadu_ps(pptr + j); + __m256i pbf16 = (__m256i)_mm512_cvtneps_pbh(pvec); + _mm256_storeu_si256((__m256i*)(pdst + j), pbf16); + } + for (; j < n; j++) + pdst[j] = float32_to_bfloat16(pptr[j]); + } + + int dd = 0; + for (; dd + 31 < D; dd += 32) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * D + dd; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op[mi] + 0); + acc[mi][1] = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op[mi] + 16); + } + + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * D + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)p_ext, (__m512bh)v0_ext); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)p_ext, (__m512bh)v1_ext); + } + } + + for (int mi = 0; mi < 4; mi++) + { + _mm512_storeu_ps(op[mi] + 0, acc[mi][0]); + _mm512_storeu_ps(op[mi] + 16, acc[mi][1]); + } + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * D + dd; + float* op1 = O + (i + 1) * D + dd; + __m512 acc00 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op0 + 0); + __m512 acc01 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op0 + 16); + __m512 acc10 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op1 + 0); + __m512 acc11 = init_zero ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op1 + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * D + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc00 = _mm512_dpbf16_ps(acc00, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc01 = _mm512_dpbf16_ps(acc01, (__m512bh)p_ext0, (__m512bh)v1_ext); + acc10 = _mm512_dpbf16_ps(acc10, (__m512bh)p_ext1, (__m512bh)v0_ext); + acc11 = _mm512_dpbf16_ps(acc11, (__m512bh)p_ext1, (__m512bh)v1_ext); + } + _mm512_storeu_ps(op0 + 0, acc00); + _mm512_storeu_ps(op1 + 0, acc10); + _mm512_storeu_ps(op0 + 16, acc01); + _mm512_storeu_ps(op1 + 16, acc11); + } + for (; i < m; i++) + { + float* optr = O + i * D + dd; + __m512 acc0 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(optr + 0); + __m512 acc1 = init_zero ? 
_mm512_setzero_ps() : _mm512_loadu_ps(optr + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * D + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext, (__m512bh)v1_ext); + } + _mm512_storeu_ps(optr + 0, acc0); + _mm512_storeu_ps(optr + 16, acc1); + } + } + + for (; dd + 15 < D; dd += 16) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * D + dd; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)p_ext, (__m512bh)v0_ext); + } + } + for (int mi = 0; mi < 4; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * D + dd; + float* op1 = O + (i + 1) * D + dd; + __m512 acc0 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op0); + __m512 acc1 = init_zero ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op1); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext1, (__m512bh)v0_ext); + } + _mm512_storeu_ps(op0, acc0); + _mm512_storeu_ps(op1, acc1); + } + for (; i < m; i++) + { + float* optr = O + i * D + dd; + __m512 acc = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc = _mm512_dpbf16_ps(acc, (__m512bh)p_ext, (__m512bh)v0_ext); + } + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < D; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * D + dd; + float acc = init_zero ? 
0.f : optr[0]; + for (int j = 0; j < n; j++) + acc += bfloat16_to_float32(p_bf16[i * n + j]) * bfloat16_to_float32(V[j * D + dd]); + optr[0] = acc; + } + } + + _mm_free(p_bf16); +} + +template +static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n) +{ + pv_gemm_bf16s_avx512bf16_kernel_t(O, P, V, m, n, false); +} + +#endif // __AVX512F__ && __AVX512BF16__ + +#endif // SDPA_X86_BF16S_H diff --git a/src/layer/x86/sdpa_x86_int8.h b/src/layer/x86/sdpa_x86_int8.h new file mode 100644 index 000000000000..e30496424098 --- /dev/null +++ b/src/layer/x86/sdpa_x86_int8.h @@ -0,0 +1,3331 @@ +#ifndef SDPA_X86_INT8_H +#define SDPA_X86_INT8_H + +static inline signed char float2int8(float v) +{ + int int32 = static_cast(roundf(v)); + if (int32 > 127) return 127; + if (int32 < -127) return -127; + return (signed char)int32; +} + +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX__ && !__AVX512VNNI__ +void decode_qk_dot_int8_avx512vnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avx512vnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avx512vnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void decode_qk_dot_int8_avxvnniint8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avxvnniint8(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avxvnniint8(float* S, const signed char* Q, const signed char* K, 
const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void decode_qk_dot_int8_avxvnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avxvnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avxvnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ +void dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width); +void dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width); +int qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len); +void decode_qk_dot_int8_avx2(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avx2(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avx2(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +void decode_pv_gemv_int8_avx2(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d); +void pv_float_int8_gemm_row_avx2(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d); +void pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed char* v, int len); +void pv_float_int8_gemm_tile_avx2(float* O, const float* P, const signed char* V, const float* vscales, int 
block_m, int block_n, int out_embed_dim); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void decode_qk_dot_int8_xop(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_xop(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_xop(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +static void dynamic_quantize_blockwise_scalar_kernel(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + for (int i = start; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 1.f : 127.f / absmax; + scales[b] = scale; + for (int i = start; i < end; i++) + { + dst[i] = float2int8(src[i] * scale); + } + } +} + +#if __SSE2__ +static inline void dynamic_quantize_blockwise_sse2_kernel(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + __m128 sign_mask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31)); + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? 
start + block_size : width; + float absmax = 0.f; + int i = start; + __m128 vmax = _mm_setzero_ps(); + for (; i + 3 < end; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + __m128 ax = _mm_andnot_ps(sign_mask, x); + vmax = _mm_max_ps(vmax, ax); + } + absmax = _mm_reduce_max_ps(vmax); + for (; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 1.f : 127.f / absmax; + scales[b] = scale; + __m128 vscale = _mm_set1_ps(scale); + i = start; + for (; i + 15 < end; i += 16) + { + __m128 x0 = _mm_loadu_ps(src + i); + __m128 x1 = _mm_loadu_ps(src + i + 4); + __m128 x2 = _mm_loadu_ps(src + i + 8); + __m128 x3 = _mm_loadu_ps(src + i + 12); + x0 = _mm_mul_ps(x0, vscale); + x1 = _mm_mul_ps(x1, vscale); + x2 = _mm_mul_ps(x2, vscale); + x3 = _mm_mul_ps(x3, vscale); + __m128i v8 = float2int8_sse(x0, x1, x2, x3); + _mm_storeu_si128((__m128i*)(dst + i), v8); + } + for (; i + 3 < end; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + x = _mm_mul_ps(x, vscale); + int32_t v4 = float2int8_sse(x); + *(int32_t*)(dst + i) = v4; + } + for (; i < end; i++) + { + dst[i] = float2int8(src[i] * scale); + } + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void dynamic_quantize_blockwise_avx2_kernel(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + __m256 sign_mask = _mm256_set1_ps(-0.f); + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + int i = start; + __m256 vmax = _mm256_setzero_ps(); + for (; i + 7 < end; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + vmax = _mm256_max_ps(vmax, ax); + } + absmax = _mm256_reduce_max_ps(vmax); + for (; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 
1.f : 127.f / absmax; + scales[b] = scale; + __m256 vscale = _mm256_set1_ps(scale); + i = start; + for (; i + 31 < end; i += 32) + { + __m256 x0 = _mm256_loadu_ps(src + i); + __m256 x1 = _mm256_loadu_ps(src + i + 8); + __m256 x2 = _mm256_loadu_ps(src + i + 16); + __m256 x3 = _mm256_loadu_ps(src + i + 24); + __m256i y0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x0, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i y1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x1, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i y2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x2, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i y3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x3, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i a0 = _mm256_castsi256_si128(y0); + __m128i a1 = _mm256_extractf128_si256(y0, 1); + __m128i a2 = _mm256_castsi256_si128(y1); + __m128i a3 = _mm256_extractf128_si256(y1, 1); + __m128i a4 = _mm256_castsi256_si128(y2); + __m128i a5 = _mm256_extractf128_si256(y2, 1); + __m128i a6 = _mm256_castsi256_si128(y3); + __m128i a7 = _mm256_extractf128_si256(y3, 1); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i b1 = _mm_packs_epi32(a2, a3); + __m128i b2 = _mm_packs_epi32(a4, a5); + __m128i b3 = _mm_packs_epi32(a6, a7); + __m128i c0 = _mm_packs_epi16(b0, b1); + __m128i c1 = _mm_packs_epi16(b2, b3); + _mm_storeu_si128((__m128i*)(dst + i), c0); + _mm_storeu_si128((__m128i*)(dst + i + 16), c1); + } + for (; i + 7 < end; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + __m256i y = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i a0 = _mm256_castsi256_si128(y); + __m128i a1 = _mm256_extractf128_si256(y, 1); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i c0 = _mm_packs_epi16(b0, b0); + int64_t v8 = _mm_cvtsi128_si64(c0); + *(int64_t*)(dst + i) = v8; + } + for (; i < end; i++) + { + dst[i] = 
#if __AVX512F__
// AVX-512 block-wise dynamic quantization.
//
// Works on blocks of 32 floats (the last block may be shorter). Each block gets
// its own scale 127 / absmax (1 for an all-zero block) written to scales[block];
// the scaled values are rounded to nearest and narrowed to int8 with saturating
// packs. Elements that do not fill a 16-lane vector go through float2int8.
static inline void dynamic_quantize_blockwise_avx512_kernel(const float* src, signed char* dst, float* scales, int width)
{
    const int block_size = 32;
    const int block_count = (width + block_size - 1) / block_size;
    // andnot with -0.f clears the sign bit, i.e. takes the absolute value
    const __m512 sign_bit = _mm512_set1_ps(-0.f);

    for (int blk = 0; blk < block_count; blk++)
    {
        const int first = blk * block_size;
        const int last = std::min(first + block_size, width);

        // pass 1: largest magnitude of the block
        __m512 vabsmax = _mm512_setzero_ps();
        int pos = first;
        while (pos + 15 < last)
        {
            const __m512 mag = _mm512_andnot_ps(sign_bit, _mm512_loadu_ps(src + pos));
            vabsmax = _mm512_max_ps(vabsmax, mag);
            pos += 16;
        }
        float absmax = _mm512_comp_reduce_max_ps(vabsmax);
        while (pos < last)
        {
            absmax = std::max(absmax, fabsf(src[pos]));
            pos++;
        }

        float block_scale = 1.f;
        if (absmax != 0.f)
            block_scale = 127.f / absmax;
        scales[blk] = block_scale;

        // pass 2: scale, round, narrow 32 -> 16 -> 8 bits
        const __m512 vscale = _mm512_set1_ps(block_scale);
        pos = first;
        while (pos + 15 < last)
        {
            const __m512 scaled = _mm512_mul_ps(_mm512_loadu_ps(src + pos), vscale);
            const __m512i q32 = _mm512_cvtps_epi32(_mm512_roundscale_ps(scaled, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
            const __m128i q16_lo = _mm_packs_epi32(_mm512_extracti32x4_epi32(q32, 0), _mm512_extracti32x4_epi32(q32, 1));
            const __m128i q16_hi = _mm_packs_epi32(_mm512_extracti32x4_epi32(q32, 2), _mm512_extracti32x4_epi32(q32, 3));
            _mm_storeu_si128((__m128i*)(dst + pos), _mm_packs_epi16(q16_lo, q16_hi));
            pos += 16;
        }
        while (pos < last)
        {
            dst[pos] = float2int8(src[pos] * block_scale);
            pos++;
        }
    }
}
#endif // __AVX512F__
+#endif +#if __AVX2__ + dynamic_quantize_blockwise_avx2(src, dst, scales, width); +#elif __SSE2__ + dynamic_quantize_blockwise_sse2_kernel(src, dst, scales, width); +#else + dynamic_quantize_blockwise_scalar_kernel(src, dst, scales, width); +#endif +#endif +} + +static inline void dynamic_quantize_rowwise_scalar_kernel(const float* src, signed char* dst, float* scale, int width) +{ + float absmax = 0.f; + for (int i = 0; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 1.f : 127.f / absmax; + *scale = s; + for (int i = 0; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} + +#if __SSE2__ +static inline void dynamic_quantize_rowwise_sse2_kernel(const float* src, signed char* dst, float* scale, int width) +{ + __m128 sign_mask = _mm_set1_ps(-0.f); + __m128 vmax = _mm_setzero_ps(); + int i = 0; + for (; i + 3 < width; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + vmax = _mm_max_ps(vmax, _mm_andnot_ps(sign_mask, x)); + } + float absmax = _mm_reduce_max_ps(vmax); + for (; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 
1.f : 127.f / absmax; + *scale = s; + __m128 vscale = _mm_set1_ps(s); + i = 0; + for (; i + 3 < width; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + __m128i yi = _mm_cvtps_epi32(_mm_mul_ps(x, vscale)); + (void)yi; + dst[i + 0] = float2int8(src[i + 0] * s); + dst[i + 1] = float2int8(src[i + 1] * s); + dst[i + 2] = float2int8(src[i + 2] * s); + dst[i + 3] = float2int8(src[i + 3] * s); + } + for (; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} +#endif + +#if __AVX2__ +inline void dynamic_quantize_rowwise_avx2_kernel(const float* src, signed char* dst, float* scale, int width) +{ + __m256 sign_mask = _mm256_set1_ps(-0.f); + __m256 vmax = _mm256_setzero_ps(); + int i = 0; + for (; i + 7 < width; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + vmax = _mm256_max_ps(vmax, _mm256_andnot_ps(sign_mask, x)); + } + float absmax = _mm256_reduce_max_ps(vmax); + for (; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 1.f : 127.f / absmax; + *scale = s; + __m256 vscale = _mm256_set1_ps(s); + i = 0; + for (; i + 7 < width; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + __m256i y = _mm256_cvtps_epi32(_mm256_mul_ps(x, vscale)); + __m128i a0 = _mm256_extracti128_si256(y, 0); + __m128i a1 = _mm256_extracti128_si256(y, 1); + __m128i b0 = _mm_packs_epi32(a0, a1); + _mm_storel_epi64((__m128i*)(dst + i), _mm_packs_epi16(b0, b0)); + } + for (; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} +#endif + +#if __AVX512F__ +static inline void dynamic_quantize_rowwise_avx512_kernel(const float* src, signed char* dst, float* scale, int width) +{ + __m512 sign_mask = _mm512_set1_ps(-0.f); + __m512 vmax = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < width; i += 16) + { + __m512 x = _mm512_loadu_ps(src + i); + vmax = _mm512_max_ps(vmax, _mm512_andnot_ps(sign_mask, x)); + } + float absmax = _mm512_comp_reduce_max_ps(vmax); + for (; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float 
s = absmax == 0.f ? 1.f : 127.f / absmax; + *scale = s; + __m512 vscale = _mm512_set1_ps(s); + i = 0; + for (; i + 15 < width; i += 16) + { + __m512 x = _mm512_loadu_ps(src + i); + __m512i y = _mm512_cvtps_epi32(_mm512_mul_ps(x, vscale)); + __m128i a0 = _mm512_extracti32x4_epi32(y, 0); + __m128i a1 = _mm512_extracti32x4_epi32(y, 1); + __m128i a2 = _mm512_extracti32x4_epi32(y, 2); + __m128i a3 = _mm512_extracti32x4_epi32(y, 3); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i b1 = _mm_packs_epi32(a2, a3); + __m128i c0 = _mm_packs_epi16(b0, b1); + _mm_storeu_si128((__m128i*)(dst + i), c0); + } + for (; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} +#endif + +static void dynamic_quantize_rowwise(const float* src, signed char* dst, float* scale, int width) +{ +#if __AVX512F__ + dynamic_quantize_rowwise_avx512_kernel(src, dst, scale, width); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + dynamic_quantize_rowwise_avx2(src, dst, scale, width); + return; + } +#endif +#if __AVX2__ + dynamic_quantize_rowwise_avx2(src, dst, scale, width); +#elif __SSE2__ + dynamic_quantize_rowwise_sse2_kernel(src, dst, scale, width); +#else + dynamic_quantize_rowwise_scalar_kernel(src, dst, scale, width); +#endif +#endif +} + +static inline void reciprocal_scales(float* scales, int num_blocks) +{ + int i = 0; +#if __AVX512F__ + for (; i + 15 < num_blocks; i += 16) + { + __m512 v = _mm512_loadu_ps(scales + i); + _mm512_storeu_ps(scales + i, _mm512_div_ps(_mm512_set1_ps(1.0f), v)); + } +#elif __AVX__ + for (; i + 7 < num_blocks; i += 8) + { + __m256 v = _mm256_loadu_ps(scales + i); + _mm256_storeu_ps(scales + i, _mm256_div_ps(_mm256_set1_ps(1.0f), v)); + } +#elif __SSE2__ + for (; i + 3 < num_blocks; i += 4) + { + __m128 v = _mm_loadu_ps(scales + i); + _mm_storeu_ps(scales + i, _mm_div_ps(_mm_set1_ps(1.0f), v)); + } +#endif + for (; i < num_blocks; i++) + { + scales[i] = 1.0f / scales[i]; + } +} + +// 
#if __SSE2__ && !__SSE4_1__
// SSE2 fallbacks for the sign-extending int8 conversions.
//
// Both _mm_cvtepi8_epi16 and _mm_cvtepi8_epi32 are SSE4.1 intrinsics, so the
// fallbacks are needed whenever SSE4.1 is missing. That includes SSSE3-only
// targets, which a guard on __SSSE3__ would wrongly leave without a definition.

// Sign-extend the low 8 bytes of x to eight 16-bit lanes.
static inline __m128i _mm_cvtepi8_epi16_sse2(__m128i x)
{
    // 0xFF for negative bytes, 0x00 otherwise: exactly the high byte of the widened value
    __m128i sign = _mm_cmpgt_epi8(_mm_setzero_si128(), x);
    return _mm_unpacklo_epi8(x, sign);
}
#define _mm_cvtepi8_epi16 _mm_cvtepi8_epi16_sse2

// Sign-extend the low 4 bytes of x to four 32-bit lanes (two widening steps).
static inline __m128i _mm_cvtepi8_epi32_sse2(__m128i x)
{
    __m128i sign = _mm_cmpgt_epi8(_mm_setzero_si128(), x);
    __m128i x16 = _mm_unpacklo_epi8(x, sign);
    __m128i sign16 = _mm_cmpgt_epi16(_mm_setzero_si128(), x16);
    return _mm_unpacklo_epi16(x16, sign16);
}
#define _mm_cvtepi8_epi32 _mm_cvtepi8_epi32_sse2
#endif // __SSE2__ && !__SSE4_1__
#if __AVX2__
// AVX2 int8 dot product of a[0..len) and b[0..len), returned as a 32-bit sum.
//
// 32 bytes are handled per iteration: each 16-byte half is sign-extended to
// int16 and multiply-accumulated into 32-bit lanes. Leftover elements are
// summed in scalar code and added to the horizontally reduced vector sum.
inline int qk_int8_dot_block_avx2_kernel(const signed char* a, const signed char* b, int len)
{
    __m256i acc = _mm256_setzero_si256();

    int pos = 0;
    while (pos + 31 < len)
    {
        // two 128-bit loads per operand, each widened to sixteen int16 lanes
        const __m256i a_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)(a + pos)));
        const __m256i b_lo = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)(b + pos)));
        const __m256i a_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)(a + pos + 16)));
        const __m256i b_hi = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*)(b + pos + 16)));

        acc = _mm256_comp_dpwssd_epi32(acc, a_lo, b_lo);
        acc = _mm256_comp_dpwssd_epi32(acc, a_hi, b_hi);
        pos += 32;
    }

    int rest = 0;
    while (pos < len)
    {
        rest += a[pos] * b[pos];
        pos++;
    }

    // fold the two 128-bit halves, then reduce the remaining four lanes
    const __m128i folded = _mm_add_epi32(_mm256_castsi256_si128(acc), _mm256_extracti128_si256(acc, 1));
    return _mm_reduce_add_epi32(folded) + rest;
}
#endif // __AVX2__
// Portable reference for the decode-time Q.K^T scores of one query row.
//
// For each of the block_n key rows starting at row n_start of K (row length d):
//   s[j] = scale * sum over 32-element blocks of (int8 dot of the block) * descale
// with descale = qscales[0] * kscales[n_start + j].
//
// The scales are applied by multiplication, as in the SSE2, XOP, AVX-VNNI and
// AVX-512 decode kernels and in the gemm-row kernels. The dispatcher hands the
// same arguments to whichever kernel is selected, so all of them must agree;
// this kernel used to divide by the scale product instead.
//
//   s        output, block_n floats
//   q        quantized query row, d values
//   K        quantized key matrix, row-major with row stride d
//   qscales  only element 0 is read
//   kscales  one scale per key row, indexed by absolute row number
static inline void decode_qk_dot_int8_scalar_kernel(float* s, const signed char* q,
        const signed char* K, const float* qscales, const float* kscales,
        int n_start, int block_n, int d, float scale)
{
    const int block_size = 32;
    for (int j = 0; j < block_n; j++)
    {
        const signed char* kptr = K + (n_start + j) * d;
        const float descale = qscales[0] * kscales[n_start + j];

        float sum = 0.f;
        for (int off = 0; off < d; off += block_size)
        {
            const int len = std::min(block_size, d - off);

            // integer dot product of one block
            int block_sum = 0;
            for (int k = 0; k < len; k++)
                block_sum += q[off + k] * kptr[off + k];

            sum += (float)block_sum * descale;
        }
        s[j] = sum * scale;
    }
}
j + 0); + const float* ks1 = kscales + (n_start + j + 1); + + float sum0 = 0.f, sum1 = 0.f; + for (int b = 0; b < num_blocks; b++) + { + float descale = qscales[0] * ks0[0]; + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_sse2_kernel(q + off, k0 + off, len); + sum0 += (float)block_sum * descale; + } + for (int b = 0; b < num_blocks; b++) + { + float descale = qscales[0] * ks1[0]; + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_sse2_kernel(q + off, k1 + off, len); + sum1 += (float)block_sum * descale; + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + for (; j < block_n; j++) + { + const signed char* kptr = K + (n_start + j) * d; + const float* ks = kscales + (n_start + j); + float sum = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_scalar_kernel(q + off, kptr + off, len); + float descale = qscales[0] * ks[0]; + sum += (float)block_sum * descale; + } + s[j] = sum * scale; + } +} +#endif // __SSE2__ + +#if __XOP__ +static inline void decode_qk_dot_int8_xop_kernel(float* s, const signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const signed char* k0 = K + (n_start + j + 0) * d; + const signed char* k1 = K + (n_start + j + 1) * d; + const float* ks0 = kscales + (n_start + j + 0); + const float* ks1 = kscales + (n_start + j + 1); + + float sum0 = 0.f, sum1 = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum0 = qk_int8_dot_block_xop_kernel(q + off, k0 + off, len); + int block_sum1 = qk_int8_dot_block_xop_kernel(q + off, k1 + off, len); 
#if __AVX2__
// AVX2 decode-time Q.K^T scores of one query row.
//
// For each of the block_n key rows starting at row n_start of K (row length d):
//   s[j] = scale * sum over 32-element blocks of (int8 dot of the block) * descale
// with descale = qscales[0] * kscales[n_start + j].
//
// Rows are processed two at a time, with a single-row loop for an odd
// remainder. Both loops apply the scales by multiplication. The two-row loop
// used to divide by the scale product while the remainder loop multiplied, so
// the result for a row depended on whether it was handled in a pair.
inline void decode_qk_dot_int8_avx2_kernel(float* s, const signed char* q,
        const signed char* K, const float* qscales, const float* kscales,
        int n_start, int block_n, int d, float scale)
{
    const int block_size = 32;
    const float qscale = qscales[0];

    int j = 0;
    for (; j + 1 < block_n; j += 2)
    {
        const signed char* k0 = K + (n_start + j + 0) * d;
        const signed char* k1 = K + (n_start + j + 1) * d;
        const float descale0 = qscale * kscales[n_start + j + 0];
        const float descale1 = qscale * kscales[n_start + j + 1];

        float sum0 = 0.f, sum1 = 0.f;
        for (int off = 0; off < d; off += block_size)
        {
            const int len = std::min(block_size, d - off);
            const int block_sum0 = qk_int8_dot_block_avx2_kernel(q + off, k0 + off, len);
            const int block_sum1 = qk_int8_dot_block_avx2_kernel(q + off, k1 + off, len);
            sum0 += (float)block_sum0 * descale0;
            sum1 += (float)block_sum1 * descale1;
        }

        s[j + 0] = sum0 * scale;
        s[j + 1] = sum1 * scale;
    }
    for (; j < block_n; j++)
    {
        const signed char* kptr = K + (n_start + j) * d;
        const float descale = qscale * kscales[n_start + j];

        float sum = 0.f;
        for (int off = 0; off < d; off += block_size)
        {
            const int len = std::min(block_size, d - off);
            const int block_sum = qk_int8_dot_block_avx2_kernel(q + off, kptr + off, len);
            sum += (float)block_sum * descale;
        }
        s[j] = sum * scale;
    }
}
#endif // __AVX2__
#if __AVX512F__
// AVX-512 decode-time Q.K^T scores of one query row.
//
// For each of the block_n key rows starting at row n_start of K (row length d):
//   s[j] = (int8 dot of q and the key row) * qscales[0] * kscales[n_start + j] * scale
// The integer dot product is accumulated over the whole row before scaling.
// Rows are processed four at a time, with a single-row loop for the remainder.
//
// VNNI path: vpdpbusd multiplies *unsigned* bytes by *signed* bytes. The key
// bytes are biased into the unsigned range (k ^ 0x80 == k + 128) and paired
// with the signed query bytes, which gives
//     sum((k + 128) * q) = sum(k * q) + 128 * sum(q)
// The correction 128 * sum(q) depends only on q, so it is computed once, over
// exactly the elements that go through the vector loop. The scalar tail
// multiplies the raw signed values and needs no correction. Biasing q instead
// would require a per-row sum of k, not a sum of q.
//
// Non-VNNI path: int8 lanes are widened to int16, so one 512-bit register
// holds 32 elements. The loop therefore advances by 32 bytes, the amount it
// actually loads.
static inline void decode_qk_dot_int8_avx512_kernel(float* s, const signed char* q,
        const signed char* K, const float* qscales, const float* kscales,
        int n_start, int block_n, int d, float scale)
{
    const float qscale = qscales[0];

#if __AVX512VNNI__
    const int vec_step = 64;
    const int vec_end = d / vec_step * vec_step;
    const __m512i bias = _mm512_set1_epi8((char)-128);

    // 128 * sum(q) over the vectorized range [0, vec_end)
    int bias_comp = 0;
    {
        __m512i qsum_acc = _mm512_setzero_si512();
        const __m512i ones = _mm512_set1_epi8(1);
        for (int off = 0; off < vec_end; off += vec_step)
        {
            __m512i q_512 = _mm512_loadu_si512((const __m512i*)(q + off));
            qsum_acc = _mm512_dpbusd_epi32(qsum_acc, ones, q_512);
        }
        bias_comp = 128 * _mm512_reduce_add_epi32(qsum_acc);
    }
#else
    const int vec_step = 32;
    const int vec_end = d / vec_step * vec_step;
    const int bias_comp = 0;
#endif

    int j = 0;
    for (; j + 3 < block_n; j += 4)
    {
        const signed char* k0 = K + (n_start + j + 0) * d;
        const signed char* k1 = K + (n_start + j + 1) * d;
        const signed char* k2 = K + (n_start + j + 2) * d;
        const signed char* k3 = K + (n_start + j + 3) * d;

        __m512i acc0 = _mm512_setzero_si512();
        __m512i acc1 = _mm512_setzero_si512();
        __m512i acc2 = _mm512_setzero_si512();
        __m512i acc3 = _mm512_setzero_si512();

        for (int off = 0; off < vec_end; off += vec_step)
        {
#if __AVX512VNNI__
            const __m512i q_s8 = _mm512_loadu_si512((const __m512i*)(q + off));

            acc0 = _mm512_dpbusd_epi32(acc0, _mm512_xor_si512(_mm512_loadu_si512((const __m512i*)(k0 + off)), bias), q_s8);
            acc1 = _mm512_dpbusd_epi32(acc1, _mm512_xor_si512(_mm512_loadu_si512((const __m512i*)(k1 + off)), bias), q_s8);
            acc2 = _mm512_dpbusd_epi32(acc2, _mm512_xor_si512(_mm512_loadu_si512((const __m512i*)(k2 + off)), bias), q_s8);
            acc3 = _mm512_dpbusd_epi32(acc3, _mm512_xor_si512(_mm512_loadu_si512((const __m512i*)(k3 + off)), bias), q_s8);
#else
            const __m512i q_s16 = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(q + off)));

            acc0 = _mm512_comp_dpwssd_epi32(acc0, q_s16, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k0 + off))));
            acc1 = _mm512_comp_dpwssd_epi32(acc1, q_s16, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k1 + off))));
            acc2 = _mm512_comp_dpwssd_epi32(acc2, q_s16, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k2 + off))));
            acc3 = _mm512_comp_dpwssd_epi32(acc3, q_s16, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k3 + off))));
#endif
        }

        // elements past the last full vector, plain signed products
        int tail0 = 0, tail1 = 0, tail2 = 0, tail3 = 0;
        for (int i = vec_end; i < d; i++)
        {
            tail0 += q[i] * k0[i];
            tail1 += q[i] * k1[i];
            tail2 += q[i] * k2[i];
            tail3 += q[i] * k3[i];
        }

        const float sum0 = (float)(_mm512_reduce_add_epi32(acc0) - bias_comp + tail0);
        const float sum1 = (float)(_mm512_reduce_add_epi32(acc1) - bias_comp + tail1);
        const float sum2 = (float)(_mm512_reduce_add_epi32(acc2) - bias_comp + tail2);
        const float sum3 = (float)(_mm512_reduce_add_epi32(acc3) - bias_comp + tail3);

        s[j + 0] = sum0 * qscale * kscales[n_start + j + 0] * scale;
        s[j + 1] = sum1 * qscale * kscales[n_start + j + 1] * scale;
        s[j + 2] = sum2 * qscale * kscales[n_start + j + 2] * scale;
        s[j + 3] = sum3 * qscale * kscales[n_start + j + 3] * scale;
    }

    for (; j < block_n; j++)
    {
        const signed char* kptr = K + (n_start + j) * d;

        __m512i acc = _mm512_setzero_si512();
        for (int off = 0; off < vec_end; off += vec_step)
        {
#if __AVX512VNNI__
            acc = _mm512_dpbusd_epi32(acc,
                                      _mm512_xor_si512(_mm512_loadu_si512((const __m512i*)(kptr + off)), bias),
                                      _mm512_loadu_si512((const __m512i*)(q + off)));
#else
            acc = _mm512_comp_dpwssd_epi32(acc,
                                           _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(q + off))),
                                           _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(kptr + off))));
#endif
        }

        int tail = 0;
        for (int i = vec_end; i < d; i++)
            tail += q[i] * kptr[i];

        const float sum = (float)(_mm512_reduce_add_epi32(acc) - bias_comp + tail);
        s[j] = sum * qscale * kscales[n_start + j] * scale;
    }
}
#endif // __AVX512F__
n_start, block_n, d, scale); + return; + } +#endif +#if __AVX2__ + decode_qk_dot_int8_avx2_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#elif __SSE2__ + decode_qk_dot_int8_sse2_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#else + decode_qk_dot_int8_scalar_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#endif +#endif +} + +// ------------------- Prefill QK Int8 GEMM Row-wise ------------------- + +static inline void qk_int8_gemm_row_scalar_kernel(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + for (int j = 0; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum += qk_int8_dot_block_scalar_kernel(q_row + off, kptr + off, len); + } + s_row[j] = (float)sum * qscale * kscales[j] * scale; + } +} + +#if __SSE2__ +static inline void qk_int8_gemm_row_sse2_kernel(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_sse2_kernel(q_row + off, k0 + off, len); + sum1 += qk_int8_dot_block_sse2_kernel(q_row + off, k1 + off, len); + } + s_row[j] = (float)sum0 * qscale * kscales[j] * scale; + s_row[j + 1] = (float)sum1 * qscale * kscales[j + 1] * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) 
#if __AVX2__
// AVX2 path for one row of the int8 Q*K^T product.
//
// s_row   : output, n floats
// q_row   : one quantized query row, d elements
// K       : n quantized key rows stored back to back, d elements each
// qscale  : dequantization scale of the query row
// kscales : per-key-row dequantization scales, n floats
// scale   : extra factor applied to every score
//
// Key rows are processed two at a time so both dot products walk the same
// query chunk together; an odd last row is handled on its own.
//
// Internal linkage (static) matches the sibling row kernels, so each
// translation unit that compiles this code keeps its own copy.
static inline void qk_int8_gemm_row_avx2_kernel(float* s_row,
        const signed char* q_row, const signed char* K, float qscale, const float* kscales,
        int n, int d, float scale)
{
    const int num_blocks = (d + 31) / 32;
    int j = 0;
    for (; j + 1 < n; j += 2)
    {
        const signed char* k0 = K + j * d;
        const signed char* k1 = K + (j + 1) * d;

        int sum0 = 0, sum1 = 0;
        for (int b = 0; b < num_blocks; b++)
        {
            const int off = b * 32;
            // off < d for every b < num_blocks, so len is always in [1, 32]
            const int len = std::min(32, d - off);
            sum0 += qk_int8_dot_block_avx2_kernel(q_row + off, k0 + off, len);
            sum1 += qk_int8_dot_block_avx2_kernel(q_row + off, k1 + off, len);
        }
        s_row[j] = (float)sum0 * qscale * kscales[j] * scale;
        s_row[j + 1] = (float)sum1 * qscale * kscales[j + 1] * scale;
    }
    for (; j < n; j++)
    {
        const signed char* kptr = K + j * d;
        int sum = 0;
        for (int b = 0; b < num_blocks; b++)
        {
            const int off = b * 32;
            const int len = std::min(32, d - off);
            sum += qk_int8_dot_block_avx2_kernel(q_row + off, kptr + off, len);
        }
        s_row[j] = (float)sum * qscale * kscales[j] * scale;
    }
}
#endif // __AVX2__
num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q_32 = _mm256_loadu_si256((const __m256i*)(q_row + off)); + __m512i q_512 = _mm512_cvtepi8_epi16(q_32); + + __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); + __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); + __m256i k2_32 = _mm256_loadu_si256((const __m256i*)(k2 + off)); + __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off)); + __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32); + __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32); + __m512i k2_512 = _mm512_cvtepi8_epi16(k2_32); + __m512i k3_512 = _mm512_cvtepi8_epi16(k3_32); + + acc0 = _mm512_comp_dpwssd_epi32(acc0, q_512, k0_512); + acc1 = _mm512_comp_dpwssd_epi32(acc1, q_512, k1_512); + acc2 = _mm512_comp_dpwssd_epi32(acc2, q_512, k2_512); + acc3 = _mm512_comp_dpwssd_epi32(acc3, q_512, k3_512); + } + else + { + int bs; + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k0 + off, len); + scalar0 += bs; + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k1 + off, len); + scalar1 += bs; + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k2 + off, len); + scalar2 += bs; + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k3 + off, len); + scalar3 += bs; + } + } + float descale0 = qscale * kscales[j]; + float descale1 = qscale * kscales[j + 1]; + float descale2 = qscale * kscales[j + 2]; + float descale3 = qscale * kscales[j + 3]; + float sum0 = (float)_mm512_reduce_add_epi32(acc0) + (float)scalar0; + float sum1 = (float)_mm512_reduce_add_epi32(acc1) + (float)scalar1; + float sum2 = (float)_mm512_reduce_add_epi32(acc2) + (float)scalar2; + float sum3 = (float)_mm512_reduce_add_epi32(acc3) + (float)scalar3; + s_row[j] = sum0 * descale0 * scale; + s_row[j + 1] = sum1 * descale1 * scale; + s_row[j + 2] = sum2 * descale2 * scale; + s_row[j + 3] = sum3 * descale3 * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + __m512i acc = 
#if __AVX512VNNI__
// AVX512-VNNI path for one row of the int8 Q*K^T product.
//
// s_row[j] receives dot(q_row, K + j * d) multiplied by qscale, kscales[j] and scale.
//
// _mm512_dpbusd_epi32 multiplies unsigned bytes by signed bytes, so every key
// byte is shifted by +128 (mod 256) before use. That inflates each dot product
// by 128 * sum(q) over the vectorized prefix; the surplus is subtracted after
// the horizontal reduction. Only whole 64-byte chunks take the vector path,
// the remaining d % 64 elements are accumulated with plain integer arithmetic.
static inline void qk_int8_gemm_row_avx512vnni_kernel(float* s_row,
        const signed char* q_row, const signed char* K, float qscale, const float* kscales,
        int n, int d, float scale)
{
    const int vec_len = (d / 64) * 64;
    const __m512i bias = _mm512_set1_epi8(128);

    // 128 * (sum of the query bytes covered by the vector path)
    __m512i qsum_v = _mm512_setzero_epi32();
    {
        const __m512i all_ones = _mm512_set1_epi8(1);
        for (int off = 0; off < vec_len; off += 64)
        {
            const __m512i qv = _mm512_loadu_si512((const __m512i*)(q_row + off));
            qsum_v = _mm512_dpbusd_epi32(qsum_v, all_ones, qv);
        }
    }
    const int correction = 128 * _mm512_reduce_add_epi32(qsum_v);

    int col = 0;

    // four key rows per pass share every query load
    while (col + 3 < n)
    {
        const signed char* ka = K + col * d;
        const signed char* kb = K + (col + 1) * d;
        const signed char* kc = K + (col + 2) * d;
        const signed char* kd = K + (col + 3) * d;

        __m512i va = _mm512_setzero_epi32();
        __m512i vb = _mm512_setzero_epi32();
        __m512i vc = _mm512_setzero_epi32();
        __m512i vd = _mm512_setzero_epi32();

        for (int off = 0; off < vec_len; off += 64)
        {
            const __m512i qv = _mm512_loadu_si512((const __m512i*)(q_row + off));
            const __m512i ua = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(ka + off)), bias);
            const __m512i ub = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(kb + off)), bias);
            const __m512i uc = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(kc + off)), bias);
            const __m512i ud = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(kd + off)), bias);

            va = _mm512_dpbusd_epi32(va, ua, qv);
            vb = _mm512_dpbusd_epi32(vb, ub, qv);
            vc = _mm512_dpbusd_epi32(vc, uc, qv);
            vd = _mm512_dpbusd_epi32(vd, ud, qv);
        }

        // leftover elements past the last whole 64-byte chunk
        int ta = 0;
        int tb = 0;
        int tc = 0;
        int td = 0;
        for (int k = vec_len; k < d; k++)
        {
            ta += q_row[k] * ka[k];
            tb += q_row[k] * kb[k];
            tc += q_row[k] * kc[k];
            td += q_row[k] * kd[k];
        }

        const float dot_a = (float)(_mm512_reduce_add_epi32(va) - correction) + (float)ta;
        const float dot_b = (float)(_mm512_reduce_add_epi32(vb) - correction) + (float)tb;
        const float dot_c = (float)(_mm512_reduce_add_epi32(vc) - correction) + (float)tc;
        const float dot_d = (float)(_mm512_reduce_add_epi32(vd) - correction) + (float)td;

        // combined dequantization factor first, global scale last
        s_row[col] = dot_a * (qscale * kscales[col]) * scale;
        s_row[col + 1] = dot_b * (qscale * kscales[col + 1]) * scale;
        s_row[col + 2] = dot_c * (qscale * kscales[col + 2]) * scale;
        s_row[col + 3] = dot_d * (qscale * kscales[col + 3]) * scale;

        col += 4;
    }

    // remaining key rows, one at a time
    while (col < n)
    {
        const signed char* krow = K + col * d;

        __m512i vacc = _mm512_setzero_epi32();
        for (int off = 0; off < vec_len; off += 64)
        {
            const __m512i qv = _mm512_loadu_si512((const __m512i*)(q_row + off));
            const __m512i uk = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(krow + off)), bias);
            vacc = _mm512_dpbusd_epi32(vacc, uk, qv);
        }

        int tail = 0;
        for (int k = vec_len; k < d; k++)
        {
            tail += q_row[k] * krow[k];
        }

        const float dot = (float)(_mm512_reduce_add_epi32(vacc) - correction) + (float)tail;
        s_row[col] = dot * qscale * kscales[col] * scale;

        col += 1;
    }
}
#endif // __AVX512VNNI__
kscales, n, d, scale); +#elif __SSE2__ + qk_int8_gemm_row_sse2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +#else + qk_int8_gemm_row_scalar_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +#endif +#endif +} + +// ------------------- Tiled QK Int8 GEMM (M-tiling) ------------------- + +static inline void qk_int8_gemm_tiled_scalar_kernel(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 2 <= m; i += 2) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + for (int j = 0; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_scalar_kernel(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_scalar_kernel(q1 + off, kptr + off, len); + } + S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; + S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_scalar_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} + +#if __SSE2__ +static inline void qk_int8_gemm_tiled_sse2_kernel(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 2 <= m; i += 2) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + int j = 0; + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + int sum00 = 0, sum01 = 0, sum10 = 0, sum11 = 0; + for (int b = 0; b < 
num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int bs; + bs = qk_int8_dot_block_sse2_kernel(q0 + off, k0 + off, len); + sum00 += bs; + bs = qk_int8_dot_block_sse2_kernel(q0 + off, k1 + off, len); + sum01 += bs; + bs = qk_int8_dot_block_sse2_kernel(q1 + off, k0 + off, len); + sum10 += bs; + bs = qk_int8_dot_block_sse2_kernel(q1 + off, k1 + off, len); + sum11 += bs; + } + S[(i + 0) * n + j] = (float)sum00 * qs0 * kscales[j] * scale; + S[(i + 0) * n + j + 1] = (float)sum01 * qs0 * kscales[j + 1] * scale; + S[(i + 1) * n + j] = (float)sum10 * qs1 * kscales[j] * scale; + S[(i + 1) * n + j + 1] = (float)sum11 * qs1 * kscales[j + 1] * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_sse2_kernel(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_sse2_kernel(q1 + off, kptr + off, len); + } + S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; + S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_sse2_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} +#endif // __SSE2__ + +#if __AVX2__ +static inline int _mm256_reduce_add_epi32(__m256i v) +{ + __m128i vlow = _mm256_castsi256_si128(v); + __m128i vhigh = _mm256_extracti128_si256(v, 1); + vlow = _mm_add_epi32(vlow, vhigh); + return _mm_reduce_add_epi32(vlow); +} + +#if __AVX512F__ +static inline float _mm512_reduce_add_ps(__m512 v) +{ + __m256 vlow = _mm512_castps512_ps256(v); + __m256 vhigh = _mm512_extractf32x8_ps(v, 1); + return _mm256_reduce_add_ps(_mm256_add_ps(vlow, vhigh)); +} + +static inline int _mm512_reduce_add_epi32(__m512i v) +{ + __m256i vlow = _mm512_castsi512_si256(v); + __m256i vhigh = _mm512_extracti32x8_epi32(v, 1); + return 
_mm256_reduce_add_epi32(_mm256_add_epi32(vlow, vhigh)); +} +#endif // __AVX512F__ + +inline void qk_int8_gemm_tiled_avx2_kernel(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 2 <= m; i += 2) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + int j = 0; + for (; j + 3 < n; j += 4) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + const signed char* k2 = K + (j + 2) * d; + const signed char* k3 = K + (j + 3) * d; + + __m256i acc00 = _mm256_setzero_si256(); + __m256i acc01 = _mm256_setzero_si256(); + __m256i acc02 = _mm256_setzero_si256(); + __m256i acc03 = _mm256_setzero_si256(); + __m256i acc10 = _mm256_setzero_si256(); + __m256i acc11 = _mm256_setzero_si256(); + __m256i acc12 = _mm256_setzero_si256(); + __m256i acc13 = _mm256_setzero_si256(); + int scalar00 = 0, scalar01 = 0, scalar02 = 0, scalar03 = 0; + int scalar10 = 0, scalar11 = 0, scalar12 = 0, scalar13 = 0; + + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off)); + __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off)); + __m256i q0_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q0_32)); + __m256i q0_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q0_32, 1)); + __m256i q1_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q1_32)); + __m256i q1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q1_32, 1)); + + __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); + __m256i k0_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k0_32)); + __m256i k0_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k0_32, 1)); + acc00 = 
_mm256_comp_dpwssd_epi32(acc00, q0_lo, k0_lo); + acc00 = _mm256_comp_dpwssd_epi32(acc00, q0_hi, k0_hi); + acc10 = _mm256_comp_dpwssd_epi32(acc10, q1_lo, k0_lo); + acc10 = _mm256_comp_dpwssd_epi32(acc10, q1_hi, k0_hi); + + __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); + __m256i k1_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k1_32)); + __m256i k1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k1_32, 1)); + acc01 = _mm256_comp_dpwssd_epi32(acc01, q0_lo, k1_lo); + acc01 = _mm256_comp_dpwssd_epi32(acc01, q0_hi, k1_hi); + acc11 = _mm256_comp_dpwssd_epi32(acc11, q1_lo, k1_lo); + acc11 = _mm256_comp_dpwssd_epi32(acc11, q1_hi, k1_hi); + + __m256i k2_32 = _mm256_loadu_si256((const __m256i*)(k2 + off)); + __m256i k2_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k2_32)); + __m256i k2_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k2_32, 1)); + acc02 = _mm256_comp_dpwssd_epi32(acc02, q0_lo, k2_lo); + acc02 = _mm256_comp_dpwssd_epi32(acc02, q0_hi, k2_hi); + acc12 = _mm256_comp_dpwssd_epi32(acc12, q1_lo, k2_lo); + acc12 = _mm256_comp_dpwssd_epi32(acc12, q1_hi, k2_hi); + + __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off)); + __m256i k3_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k3_32)); + __m256i k3_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k3_32, 1)); + acc03 = _mm256_comp_dpwssd_epi32(acc03, q0_lo, k3_lo); + acc03 = _mm256_comp_dpwssd_epi32(acc03, q0_hi, k3_hi); + acc13 = _mm256_comp_dpwssd_epi32(acc13, q1_lo, k3_lo); + acc13 = _mm256_comp_dpwssd_epi32(acc13, q1_hi, k3_hi); + } + else + { + int bs; + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k0 + off, len); + scalar00 += bs; + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k1 + off, len); + scalar01 += bs; + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k2 + off, len); + scalar02 += bs; + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k3 + off, len); + scalar03 += bs; + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k0 + off, len); + scalar10 += bs; + bs = 
qk_int8_dot_block_avx2_kernel(q1 + off, k1 + off, len); + scalar11 += bs; + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k2 + off, len); + scalar12 += bs; + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k3 + off, len); + scalar13 += bs; + } + } + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale02 = qs0 * kscales[j + 2]; + float descale03 = qs0 * kscales[j + 3]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale12 = qs1 * kscales[j + 2]; + float descale13 = qs1 * kscales[j + 3]; + S[(i + 0) * n + j] = ((float)_mm256_reduce_add_epi32(acc00) + (float)scalar00) * descale00 * scale; + S[(i + 0) * n + j + 1] = ((float)_mm256_reduce_add_epi32(acc01) + (float)scalar01) * descale01 * scale; + S[(i + 0) * n + j + 2] = ((float)_mm256_reduce_add_epi32(acc02) + (float)scalar02) * descale02 * scale; + S[(i + 0) * n + j + 3] = ((float)_mm256_reduce_add_epi32(acc03) + (float)scalar03) * descale03 * scale; + S[(i + 1) * n + j] = ((float)_mm256_reduce_add_epi32(acc10) + (float)scalar10) * descale10 * scale; + S[(i + 1) * n + j + 1] = ((float)_mm256_reduce_add_epi32(acc11) + (float)scalar11) * descale11 * scale; + S[(i + 1) * n + j + 2] = ((float)_mm256_reduce_add_epi32(acc12) + (float)scalar12) * descale12 * scale; + S[(i + 1) * n + j + 3] = ((float)_mm256_reduce_add_epi32(acc13) + (float)scalar13) * descale13 * scale; + } + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + int sum00 = 0, sum01 = 0, sum10 = 0, sum11 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int bs; + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k0 + off, len); + sum00 += bs; + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k1 + off, len); + sum01 += bs; + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k0 + off, len); + sum10 += bs; + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k1 + 
#if __AVX512F__
// AVX512 path for the full int8 Q*K^T product, tiled over four query rows.
//
// S       : output, m rows of n floats
// Q       : m quantized query rows stored back to back, d elements each
// K       : n quantized key rows stored back to back, d elements each
// qscales : per-query-row dequantization scales, m floats
// kscales : per-key-row dequantization scales, n floats
// scale   : extra factor applied to every score
//
// For each group of four query rows the key rows are consumed as 4-wide tiles,
// then one optional 2-wide tile, then single rows. Inside a tile every whole
// 32-byte chunk is sign-extended to 16-bit lanes and accumulated through
// _mm512_comp_dpwssd_epi32; a final partial chunk (d % 32 elements) goes
// through qk_int8_dot_block_avx512_kernel and is kept in plain int counters.
// Query rows left over after the groups of four use the single-row kernel.
static inline void qk_int8_gemm_tiled_avx512_kernel(float* S,
        const signed char* Q, const signed char* K,
        const float* qscales, const float* kscales,
        int m, int n, int d, float scale)
{
    const int num_blocks = (d + 31) / 32;
    int i = 0;
    for (; i + 4 <= m; i += 4)
    {
        const signed char* q0 = Q + (i + 0) * d;
        const signed char* q1 = Q + (i + 1) * d;
        const signed char* q2 = Q + (i + 2) * d;
        const signed char* q3 = Q + (i + 3) * d;
        float qs0 = qscales[i + 0];
        float qs1 = qscales[i + 1];
        float qs2 = qscales[i + 2];
        float qs3 = qscales[i + 3];
        int j = 0;
        // 4 query rows x 4 key rows: accRC / scalarRC belong to query row R and key row C
        for (; j + 3 < n; j += 4)
        {
            const signed char* k0 = K + j * d;
            const signed char* k1 = K + (j + 1) * d;
            const signed char* k2 = K + (j + 2) * d;
            const signed char* k3 = K + (j + 3) * d;

            __m512i acc00 = _mm512_setzero_epi32();
            __m512i acc01 = _mm512_setzero_epi32();
            __m512i acc02 = _mm512_setzero_epi32();
            __m512i acc03 = _mm512_setzero_epi32();
            __m512i acc10 = _mm512_setzero_epi32();
            __m512i acc11 = _mm512_setzero_epi32();
            __m512i acc12 = _mm512_setzero_epi32();
            __m512i acc13 = _mm512_setzero_epi32();
            __m512i acc20 = _mm512_setzero_epi32();
            __m512i acc21 = _mm512_setzero_epi32();
            __m512i acc22 = _mm512_setzero_epi32();
            __m512i acc23 = _mm512_setzero_epi32();
            __m512i acc30 = _mm512_setzero_epi32();
            __m512i acc31 = _mm512_setzero_epi32();
            __m512i acc32 = _mm512_setzero_epi32();
            __m512i acc33 = _mm512_setzero_epi32();
            int scalar00 = 0, scalar01 = 0, scalar02 = 0, scalar03 = 0;
            int scalar10 = 0, scalar11 = 0, scalar12 = 0, scalar13 = 0;
            int scalar20 = 0, scalar21 = 0, scalar22 = 0, scalar23 = 0;
            int scalar30 = 0, scalar31 = 0, scalar32 = 0, scalar33 = 0;

            for (int b = 0; b < num_blocks; b++)
            {
                int off = b * 32;
                int len = std::min(32, d - off);
                if (len <= 0) continue;
                if (len == 32)
                {
                    // whole chunk: widen 32 int8 values to 32 int16 lanes per row
                    __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off));
                    __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off));
                    __m256i q2_32 = _mm256_loadu_si256((const __m256i*)(q2 + off));
                    __m256i q3_32 = _mm256_loadu_si256((const __m256i*)(q3 + off));
                    __m512i q0_512 = _mm512_cvtepi8_epi16(q0_32);
                    __m512i q1_512 = _mm512_cvtepi8_epi16(q1_32);
                    __m512i q2_512 = _mm512_cvtepi8_epi16(q2_32);
                    __m512i q3_512 = _mm512_cvtepi8_epi16(q3_32);

                    __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off));
                    __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off));
                    __m256i k2_32 = _mm256_loadu_si256((const __m256i*)(k2 + off));
                    __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off));
                    __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32);
                    __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32);
                    __m512i k2_512 = _mm512_cvtepi8_epi16(k2_32);
                    __m512i k3_512 = _mm512_cvtepi8_epi16(k3_32);

                    acc00 = _mm512_comp_dpwssd_epi32(acc00, q0_512, k0_512);
                    acc01 = _mm512_comp_dpwssd_epi32(acc01, q0_512, k1_512);
                    acc02 = _mm512_comp_dpwssd_epi32(acc02, q0_512, k2_512);
                    acc03 = _mm512_comp_dpwssd_epi32(acc03, q0_512, k3_512);
                    acc10 = _mm512_comp_dpwssd_epi32(acc10, q1_512, k0_512);
                    acc11 = _mm512_comp_dpwssd_epi32(acc11, q1_512, k1_512);
                    acc12 = _mm512_comp_dpwssd_epi32(acc12, q1_512, k2_512);
                    acc13 = _mm512_comp_dpwssd_epi32(acc13, q1_512, k3_512);
                    acc20 = _mm512_comp_dpwssd_epi32(acc20, q2_512, k0_512);
                    acc21 = _mm512_comp_dpwssd_epi32(acc21, q2_512, k1_512);
                    acc22 = _mm512_comp_dpwssd_epi32(acc22, q2_512, k2_512);
                    acc23 = _mm512_comp_dpwssd_epi32(acc23, q2_512, k3_512);
                    acc30 = _mm512_comp_dpwssd_epi32(acc30, q3_512, k0_512);
                    acc31 = _mm512_comp_dpwssd_epi32(acc31, q3_512, k1_512);
                    acc32 = _mm512_comp_dpwssd_epi32(acc32, q3_512, k2_512);
                    acc33 = _mm512_comp_dpwssd_epi32(acc33, q3_512, k3_512);
                }
                else
                {
                    // partial last chunk: 32-byte loads are not used here, the block helper gets the length
                    int bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, k0 + off, len);
                    scalar00 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, k1 + off, len);
                    scalar01 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, k2 + off, len);
                    scalar02 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, k3 + off, len);
                    scalar03 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, k0 + off, len);
                    scalar10 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, k1 + off, len);
                    scalar11 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, k2 + off, len);
                    scalar12 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, k3 + off, len);
                    scalar13 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, k0 + off, len);
                    scalar20 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, k1 + off, len);
                    scalar21 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, k2 + off, len);
                    scalar22 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, k3 + off, len);
                    scalar23 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, k0 + off, len);
                    scalar30 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, k1 + off, len);
                    scalar31 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, k2 + off, len);
                    scalar32 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, k3 + off, len);
                    scalar33 += bs;
                }
            }
            // combined query*key dequantization factor per output element
            float descale00 = qs0 * kscales[j];
            float descale01 = qs0 * kscales[j + 1];
            float descale02 = qs0 * kscales[j + 2];
            float descale03 = qs0 * kscales[j + 3];
            float descale10 = qs1 * kscales[j];
            float descale11 = qs1 * kscales[j + 1];
            float descale12 = qs1 * kscales[j + 2];
            float descale13 = qs1 * kscales[j + 3];
            float descale20 = qs2 * kscales[j];
            float descale21 = qs2 * kscales[j + 1];
            float descale22 = qs2 * kscales[j + 2];
            float descale23 = qs2 * kscales[j + 3];
            float descale30 = qs3 * kscales[j];
            float descale31 = qs3 * kscales[j + 1];
            float descale32 = qs3 * kscales[j + 2];
            float descale33 = qs3 * kscales[j + 3];
            // vector lanes are reduced to one int and joined with the partial-chunk sum
            float sum00 = (float)_mm512_reduce_add_epi32(acc00) + (float)scalar00;
            float sum01 = (float)_mm512_reduce_add_epi32(acc01) + (float)scalar01;
            float sum02 = (float)_mm512_reduce_add_epi32(acc02) + (float)scalar02;
            float sum03 = (float)_mm512_reduce_add_epi32(acc03) + (float)scalar03;
            float sum10 = (float)_mm512_reduce_add_epi32(acc10) + (float)scalar10;
            float sum11 = (float)_mm512_reduce_add_epi32(acc11) + (float)scalar11;
            float sum12 = (float)_mm512_reduce_add_epi32(acc12) + (float)scalar12;
            float sum13 = (float)_mm512_reduce_add_epi32(acc13) + (float)scalar13;
            float sum20 = (float)_mm512_reduce_add_epi32(acc20) + (float)scalar20;
            float sum21 = (float)_mm512_reduce_add_epi32(acc21) + (float)scalar21;
            float sum22 = (float)_mm512_reduce_add_epi32(acc22) + (float)scalar22;
            float sum23 = (float)_mm512_reduce_add_epi32(acc23) + (float)scalar23;
            float sum30 = (float)_mm512_reduce_add_epi32(acc30) + (float)scalar30;
            float sum31 = (float)_mm512_reduce_add_epi32(acc31) + (float)scalar31;
            float sum32 = (float)_mm512_reduce_add_epi32(acc32) + (float)scalar32;
            float sum33 = (float)_mm512_reduce_add_epi32(acc33) + (float)scalar33;
            S[(i + 0) * n + j + 0] = sum00 * descale00 * scale;
            S[(i + 0) * n + j + 1] = sum01 * descale01 * scale;
            S[(i + 0) * n + j + 2] = sum02 * descale02 * scale;
            S[(i + 0) * n + j + 3] = sum03 * descale03 * scale;
            S[(i + 1) * n + j + 0] = sum10 * descale10 * scale;
            S[(i + 1) * n + j + 1] = sum11 * descale11 * scale;
            S[(i + 1) * n + j + 2] = sum12 * descale12 * scale;
            S[(i + 1) * n + j + 3] = sum13 * descale13 * scale;
            S[(i + 2) * n + j + 0] = sum20 * descale20 * scale;
            S[(i + 2) * n + j + 1] = sum21 * descale21 * scale;
            S[(i + 2) * n + j + 2] = sum22 * descale22 * scale;
            S[(i + 2) * n + j + 3] = sum23 * descale23 * scale;
            S[(i + 3) * n + j + 0] = sum30 * descale30 * scale;
            S[(i + 3) * n + j + 1] = sum31 * descale31 * scale;
            S[(i + 3) * n + j + 2] = sum32 * descale32 * scale;
            S[(i + 3) * n + j + 3] = sum33 * descale33 * scale;
        }
        // 4 query rows x 2 key rows: runs at most once, when two or three key rows remain
        for (; j + 1 < n; j += 2)
        {
            const signed char* k0 = K + j * d;
            const signed char* k1 = K + (j + 1) * d;

            __m512i acc00 = _mm512_setzero_epi32();
            __m512i acc01 = _mm512_setzero_epi32();
            __m512i acc10 = _mm512_setzero_epi32();
            __m512i acc11 = _mm512_setzero_epi32();
            __m512i acc20 = _mm512_setzero_epi32();
            __m512i acc21 = _mm512_setzero_epi32();
            __m512i acc30 = _mm512_setzero_epi32();
            __m512i acc31 = _mm512_setzero_epi32();
            int scalar00 = 0, scalar01 = 0, scalar10 = 0, scalar11 = 0;
            int scalar20 = 0, scalar21 = 0, scalar30 = 0, scalar31 = 0;

            for (int b = 0; b < num_blocks; b++)
            {
                int off = b * 32;
                int len = std::min(32, d - off);
                if (len <= 0) continue;
                if (len == 32)
                {
                    __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off));
                    __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off));
                    __m256i q2_32 = _mm256_loadu_si256((const __m256i*)(q2 + off));
                    __m256i q3_32 = _mm256_loadu_si256((const __m256i*)(q3 + off));
                    __m512i q0_512 = _mm512_cvtepi8_epi16(q0_32);
                    __m512i q1_512 = _mm512_cvtepi8_epi16(q1_32);
                    __m512i q2_512 = _mm512_cvtepi8_epi16(q2_32);
                    __m512i q3_512 = _mm512_cvtepi8_epi16(q3_32);

                    __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off));
                    __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off));
                    __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32);
                    __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32);

                    acc00 = _mm512_comp_dpwssd_epi32(acc00, q0_512, k0_512);
                    acc01 = _mm512_comp_dpwssd_epi32(acc01, q0_512, k1_512);
                    acc10 = _mm512_comp_dpwssd_epi32(acc10, q1_512, k0_512);
                    acc11 = _mm512_comp_dpwssd_epi32(acc11, q1_512, k1_512);
                    acc20 = _mm512_comp_dpwssd_epi32(acc20, q2_512, k0_512);
                    acc21 = _mm512_comp_dpwssd_epi32(acc21, q2_512, k1_512);
                    acc30 = _mm512_comp_dpwssd_epi32(acc30, q3_512, k0_512);
                    acc31 = _mm512_comp_dpwssd_epi32(acc31, q3_512, k1_512);
                }
                else
                {
                    int bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, k0 + off, len);
                    scalar00 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, k1 + off, len);
                    scalar01 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, k0 + off, len);
                    scalar10 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, k1 + off, len);
                    scalar11 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, k0 + off, len);
                    scalar20 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, k1 + off, len);
                    scalar21 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, k0 + off, len);
                    scalar30 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, k1 + off, len);
                    scalar31 += bs;
                }
            }
            float descale00 = qs0 * kscales[j];
            float descale01 = qs0 * kscales[j + 1];
            float descale10 = qs1 * kscales[j];
            float descale11 = qs1 * kscales[j + 1];
            float descale20 = qs2 * kscales[j];
            float descale21 = qs2 * kscales[j + 1];
            float descale30 = qs3 * kscales[j];
            float descale31 = qs3 * kscales[j + 1];
            float sum00 = (float)_mm512_reduce_add_epi32(acc00) + (float)scalar00;
            float sum01 = (float)_mm512_reduce_add_epi32(acc01) + (float)scalar01;
            float sum10 = (float)_mm512_reduce_add_epi32(acc10) + (float)scalar10;
            float sum11 = (float)_mm512_reduce_add_epi32(acc11) + (float)scalar11;
            float sum20 = (float)_mm512_reduce_add_epi32(acc20) + (float)scalar20;
            float sum21 = (float)_mm512_reduce_add_epi32(acc21) + (float)scalar21;
            float sum30 = (float)_mm512_reduce_add_epi32(acc30) + (float)scalar30;
            float sum31 = (float)_mm512_reduce_add_epi32(acc31) + (float)scalar31;
            S[(i + 0) * n + j] = sum00 * descale00 * scale;
            S[(i + 0) * n + j + 1] = sum01 * descale01 * scale;
            S[(i + 1) * n + j] = sum10 * descale10 * scale;
            S[(i + 1) * n + j + 1] = sum11 * descale11 * scale;
            S[(i + 2) * n + j] = sum20 * descale20 * scale;
            S[(i + 2) * n + j + 1] = sum21 * descale21 * scale;
            S[(i + 3) * n + j] = sum30 * descale30 * scale;
            S[(i + 3) * n + j + 1] = sum31 * descale31 * scale;
        }
        // 4 query rows x 1 key row: the last key row when n is odd
        for (; j < n; j++)
        {
            const signed char* kptr = K + j * d;

            __m512i acc0 = _mm512_setzero_epi32();
            __m512i acc1 = _mm512_setzero_epi32();
            __m512i acc2 = _mm512_setzero_epi32();
            __m512i acc3 = _mm512_setzero_epi32();
            int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0;

            for (int b = 0; b < num_blocks; b++)
            {
                int off = b * 32;
                int len = std::min(32, d - off);
                if (len <= 0) continue;
                if (len == 32)
                {
                    __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off));
                    __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off));
                    __m256i q2_32 = _mm256_loadu_si256((const __m256i*)(q2 + off));
                    __m256i q3_32 = _mm256_loadu_si256((const __m256i*)(q3 + off));
                    __m512i q0_512 = _mm512_cvtepi8_epi16(q0_32);
                    __m512i q1_512 = _mm512_cvtepi8_epi16(q1_32);
                    __m512i q2_512 = _mm512_cvtepi8_epi16(q2_32);
                    __m512i q3_512 = _mm512_cvtepi8_epi16(q3_32);

                    __m256i k_32 = _mm256_loadu_si256((const __m256i*)(kptr + off));
                    __m512i k_512 = _mm512_cvtepi8_epi16(k_32);

                    acc0 = _mm512_comp_dpwssd_epi32(acc0, q0_512, k_512);
                    acc1 = _mm512_comp_dpwssd_epi32(acc1, q1_512, k_512);
                    acc2 = _mm512_comp_dpwssd_epi32(acc2, q2_512, k_512);
                    acc3 = _mm512_comp_dpwssd_epi32(acc3, q3_512, k_512);
                }
                else
                {
                    int bs;
                    bs = qk_int8_dot_block_avx512_kernel(q0 + off, kptr + off, len);
                    scalar0 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q1 + off, kptr + off, len);
                    scalar1 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q2 + off, kptr + off, len);
                    scalar2 += bs;
                    bs = qk_int8_dot_block_avx512_kernel(q3 + off, kptr + off, len);
                    scalar3 += bs;
                }
            }
            // here descale is only the key scale; the query scale is applied separately below
            float descale = kscales[j];
            float sum0 = (float)_mm512_reduce_add_epi32(acc0) + (float)scalar0;
            float sum1 = (float)_mm512_reduce_add_epi32(acc1) + (float)scalar1;
            float sum2 = (float)_mm512_reduce_add_epi32(acc2) + (float)scalar2;
            float sum3 = (float)_mm512_reduce_add_epi32(acc3) + (float)scalar3;
            // NOTE(review): this multiplies sum * qs * kscale left to right, while the wider tiles
            // multiply by the precomputed (qs * kscale) product - rounding may differ slightly
            S[(i + 0) * n + j] = sum0 * qs0 * descale * scale;
            S[(i + 1) * n + j] = sum1 * qs1 * descale * scale;
            S[(i + 2) * n + j] = sum2 * qs2 * descale * scale;
            S[(i + 3) * n + j] = sum3 * qs3 * descale * scale;
        }
    }
    // up to three query rows remain after the groups of four
    for (; i < m; i++)
    {
        qk_int8_gemm_row_avx512_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale);
    }
}
#endif // __AVX512F__
_mm512_dpbusd_epi32(sum0, ones, _mm512_loadu_si512((const __m512i*)(q0 + off))); + sum1 = _mm512_dpbusd_epi32(sum1, ones, _mm512_loadu_si512((const __m512i*)(q1 + off))); + sum2 = _mm512_dpbusd_epi32(sum2, ones, _mm512_loadu_si512((const __m512i*)(q2 + off))); + sum3 = _mm512_dpbusd_epi32(sum3, ones, _mm512_loadu_si512((const __m512i*)(q3 + off))); + } + qsum0_64byte = _mm512_reduce_add_epi32(sum0); + qsum1_64byte = _mm512_reduce_add_epi32(sum1); + qsum2_64byte = _mm512_reduce_add_epi32(sum2); + qsum3_64byte = _mm512_reduce_add_epi32(sum3); + } + + int j = 0; + for (; j + 3 < n; j += 4) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + const signed char* k2 = K + (j + 2) * d; + const signed char* k3 = K + (j + 3) * d; + + __m512i acc00 = _mm512_setzero_epi32(); + __m512i acc01 = _mm512_setzero_epi32(); + __m512i acc02 = _mm512_setzero_epi32(); + __m512i acc03 = _mm512_setzero_epi32(); + __m512i acc10 = _mm512_setzero_epi32(); + __m512i acc11 = _mm512_setzero_epi32(); + __m512i acc12 = _mm512_setzero_epi32(); + __m512i acc13 = _mm512_setzero_epi32(); + __m512i acc20 = _mm512_setzero_epi32(); + __m512i acc21 = _mm512_setzero_epi32(); + __m512i acc22 = _mm512_setzero_epi32(); + __m512i acc23 = _mm512_setzero_epi32(); + __m512i acc30 = _mm512_setzero_epi32(); + __m512i acc31 = _mm512_setzero_epi32(); + __m512i acc32 = _mm512_setzero_epi32(); + __m512i acc33 = _mm512_setzero_epi32(); + int scalar00 = 0, scalar01 = 0, scalar02 = 0, scalar03 = 0; + int scalar10 = 0, scalar11 = 0, scalar12 = 0, scalar13 = 0; + int scalar20 = 0, scalar21 = 0, scalar22 = 0, scalar23 = 0; + int scalar30 = 0, scalar31 = 0, scalar32 = 0, scalar33 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q0_512 = _mm512_loadu_si512((const __m512i*)(q0 + off)); + __m512i q1_512 = _mm512_loadu_si512((const __m512i*)(q1 + off)); + __m512i q2_512 = _mm512_loadu_si512((const __m512i*)(q2 + off)); + __m512i q3_512 = _mm512_loadu_si512((const 
__m512i*)(q3 + off)); + + __m512i k0_512 = _mm512_loadu_si512((const __m512i*)(k0 + off)); + __m512i k1_512 = _mm512_loadu_si512((const __m512i*)(k1 + off)); + __m512i k2_512 = _mm512_loadu_si512((const __m512i*)(k2 + off)); + __m512i k3_512 = _mm512_loadu_si512((const __m512i*)(k3 + off)); + __m512i k0_u8 = _mm512_add_epi8(k0_512, _mm512_set1_epi8(128)); + __m512i k1_u8 = _mm512_add_epi8(k1_512, _mm512_set1_epi8(128)); + __m512i k2_u8 = _mm512_add_epi8(k2_512, _mm512_set1_epi8(128)); + __m512i k3_u8 = _mm512_add_epi8(k3_512, _mm512_set1_epi8(128)); + + acc00 = _mm512_dpbusd_epi32(acc00, k0_u8, q0_512); + acc01 = _mm512_dpbusd_epi32(acc01, k1_u8, q0_512); + acc02 = _mm512_dpbusd_epi32(acc02, k2_u8, q0_512); + acc03 = _mm512_dpbusd_epi32(acc03, k3_u8, q0_512); + acc10 = _mm512_dpbusd_epi32(acc10, k0_u8, q1_512); + acc11 = _mm512_dpbusd_epi32(acc11, k1_u8, q1_512); + acc12 = _mm512_dpbusd_epi32(acc12, k2_u8, q1_512); + acc13 = _mm512_dpbusd_epi32(acc13, k3_u8, q1_512); + acc20 = _mm512_dpbusd_epi32(acc20, k0_u8, q2_512); + acc21 = _mm512_dpbusd_epi32(acc21, k1_u8, q2_512); + acc22 = _mm512_dpbusd_epi32(acc22, k2_u8, q2_512); + acc23 = _mm512_dpbusd_epi32(acc23, k3_u8, q2_512); + acc30 = _mm512_dpbusd_epi32(acc30, k0_u8, q3_512); + acc31 = _mm512_dpbusd_epi32(acc31, k1_u8, q3_512); + acc32 = _mm512_dpbusd_epi32(acc32, k2_u8, q3_512); + acc33 = _mm512_dpbusd_epi32(acc33, k3_u8, q3_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar00 += q0[k] * k0[k]; + scalar01 += q0[k] * k1[k]; + scalar02 += q0[k] * k2[k]; + scalar03 += q0[k] * k3[k]; + scalar10 += q1[k] * k0[k]; + scalar11 += q1[k] * k1[k]; + scalar12 += q1[k] * k2[k]; + scalar13 += q1[k] * k3[k]; + scalar20 += q2[k] * k0[k]; + scalar21 += q2[k] * k1[k]; + scalar22 += q2[k] * k2[k]; + scalar23 += q2[k] * k3[k]; + scalar30 += q3[k] * k0[k]; + scalar31 += q3[k] * k1[k]; + scalar32 += q3[k] * k2[k]; + scalar33 += q3[k] * k3[k]; + } + } + + 
float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale02 = qs0 * kscales[j + 2]; + float descale03 = qs0 * kscales[j + 3]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale12 = qs1 * kscales[j + 2]; + float descale13 = qs1 * kscales[j + 3]; + float descale20 = qs2 * kscales[j]; + float descale21 = qs2 * kscales[j + 1]; + float descale22 = qs2 * kscales[j + 2]; + float descale23 = qs2 * kscales[j + 3]; + float descale30 = qs3 * kscales[j]; + float descale31 = qs3 * kscales[j + 1]; + float descale32 = qs3 * kscales[j + 2]; + float descale33 = qs3 * kscales[j + 3]; + float sum00 = (float)(_mm512_reduce_add_epi32(acc00) - 128 * qsum0_64byte) + (float)scalar00; + float sum01 = (float)(_mm512_reduce_add_epi32(acc01) - 128 * qsum0_64byte) + (float)scalar01; + float sum02 = (float)(_mm512_reduce_add_epi32(acc02) - 128 * qsum0_64byte) + (float)scalar02; + float sum03 = (float)(_mm512_reduce_add_epi32(acc03) - 128 * qsum0_64byte) + (float)scalar03; + float sum10 = (float)(_mm512_reduce_add_epi32(acc10) - 128 * qsum1_64byte) + (float)scalar10; + float sum11 = (float)(_mm512_reduce_add_epi32(acc11) - 128 * qsum1_64byte) + (float)scalar11; + float sum12 = (float)(_mm512_reduce_add_epi32(acc12) - 128 * qsum1_64byte) + (float)scalar12; + float sum13 = (float)(_mm512_reduce_add_epi32(acc13) - 128 * qsum1_64byte) + (float)scalar13; + float sum20 = (float)(_mm512_reduce_add_epi32(acc20) - 128 * qsum2_64byte) + (float)scalar20; + float sum21 = (float)(_mm512_reduce_add_epi32(acc21) - 128 * qsum2_64byte) + (float)scalar21; + float sum22 = (float)(_mm512_reduce_add_epi32(acc22) - 128 * qsum2_64byte) + (float)scalar22; + float sum23 = (float)(_mm512_reduce_add_epi32(acc23) - 128 * qsum2_64byte) + (float)scalar23; + float sum30 = (float)(_mm512_reduce_add_epi32(acc30) - 128 * qsum3_64byte) + (float)scalar30; + float sum31 = (float)(_mm512_reduce_add_epi32(acc31) - 128 * qsum3_64byte) + (float)scalar31; + 
float sum32 = (float)(_mm512_reduce_add_epi32(acc32) - 128 * qsum3_64byte) + (float)scalar32; + float sum33 = (float)(_mm512_reduce_add_epi32(acc33) - 128 * qsum3_64byte) + (float)scalar33; + S[(i + 0) * n + j + 0] = sum00 * descale00 * scale; + S[(i + 0) * n + j + 1] = sum01 * descale01 * scale; + S[(i + 0) * n + j + 2] = sum02 * descale02 * scale; + S[(i + 0) * n + j + 3] = sum03 * descale03 * scale; + S[(i + 1) * n + j + 0] = sum10 * descale10 * scale; + S[(i + 1) * n + j + 1] = sum11 * descale11 * scale; + S[(i + 1) * n + j + 2] = sum12 * descale12 * scale; + S[(i + 1) * n + j + 3] = sum13 * descale13 * scale; + S[(i + 2) * n + j + 0] = sum20 * descale20 * scale; + S[(i + 2) * n + j + 1] = sum21 * descale21 * scale; + S[(i + 2) * n + j + 2] = sum22 * descale22 * scale; + S[(i + 2) * n + j + 3] = sum23 * descale23 * scale; + S[(i + 3) * n + j + 0] = sum30 * descale30 * scale; + S[(i + 3) * n + j + 1] = sum31 * descale31 * scale; + S[(i + 3) * n + j + 2] = sum32 * descale32 * scale; + S[(i + 3) * n + j + 3] = sum33 * descale33 * scale; + } + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + + __m512i acc00 = _mm512_setzero_epi32(); + __m512i acc01 = _mm512_setzero_epi32(); + __m512i acc10 = _mm512_setzero_epi32(); + __m512i acc11 = _mm512_setzero_epi32(); + __m512i acc20 = _mm512_setzero_epi32(); + __m512i acc21 = _mm512_setzero_epi32(); + __m512i acc30 = _mm512_setzero_epi32(); + __m512i acc31 = _mm512_setzero_epi32(); + int scalar00 = 0, scalar01 = 0, scalar10 = 0, scalar11 = 0; + int scalar20 = 0, scalar21 = 0, scalar30 = 0, scalar31 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q0_512 = _mm512_loadu_si512((const __m512i*)(q0 + off)); + __m512i q1_512 = _mm512_loadu_si512((const __m512i*)(q1 + off)); + __m512i q2_512 = _mm512_loadu_si512((const __m512i*)(q2 + off)); + __m512i q3_512 = _mm512_loadu_si512((const __m512i*)(q3 + off)); + + __m512i k0_512 = 
_mm512_loadu_si512((const __m512i*)(k0 + off)); + __m512i k1_512 = _mm512_loadu_si512((const __m512i*)(k1 + off)); + __m512i k0_u8 = _mm512_add_epi8(k0_512, _mm512_set1_epi8(128)); + __m512i k1_u8 = _mm512_add_epi8(k1_512, _mm512_set1_epi8(128)); + + acc00 = _mm512_dpbusd_epi32(acc00, k0_u8, q0_512); + acc01 = _mm512_dpbusd_epi32(acc01, k1_u8, q0_512); + acc10 = _mm512_dpbusd_epi32(acc10, k0_u8, q1_512); + acc11 = _mm512_dpbusd_epi32(acc11, k1_u8, q1_512); + acc20 = _mm512_dpbusd_epi32(acc20, k0_u8, q2_512); + acc21 = _mm512_dpbusd_epi32(acc21, k1_u8, q2_512); + acc30 = _mm512_dpbusd_epi32(acc30, k0_u8, q3_512); + acc31 = _mm512_dpbusd_epi32(acc31, k1_u8, q3_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar00 += q0[k] * k0[k]; + scalar01 += q0[k] * k1[k]; + scalar10 += q1[k] * k0[k]; + scalar11 += q1[k] * k1[k]; + scalar20 += q2[k] * k0[k]; + scalar21 += q2[k] * k1[k]; + scalar30 += q3[k] * k0[k]; + scalar31 += q3[k] * k1[k]; + } + } + + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale20 = qs2 * kscales[j]; + float descale21 = qs2 * kscales[j + 1]; + float descale30 = qs3 * kscales[j]; + float descale31 = qs3 * kscales[j + 1]; + float sum00 = (float)(_mm512_reduce_add_epi32(acc00) - 128 * qsum0_64byte) + (float)scalar00; + float sum01 = (float)(_mm512_reduce_add_epi32(acc01) - 128 * qsum0_64byte) + (float)scalar01; + float sum10 = (float)(_mm512_reduce_add_epi32(acc10) - 128 * qsum1_64byte) + (float)scalar10; + float sum11 = (float)(_mm512_reduce_add_epi32(acc11) - 128 * qsum1_64byte) + (float)scalar11; + float sum20 = (float)(_mm512_reduce_add_epi32(acc20) - 128 * qsum2_64byte) + (float)scalar20; + float sum21 = (float)(_mm512_reduce_add_epi32(acc21) - 128 * qsum2_64byte) + (float)scalar21; + float sum30 = (float)(_mm512_reduce_add_epi32(acc30) - 128 * 
qsum3_64byte) + (float)scalar30; + float sum31 = (float)(_mm512_reduce_add_epi32(acc31) - 128 * qsum3_64byte) + (float)scalar31; + S[(i + 0) * n + j] = sum00 * descale00 * scale; + S[(i + 0) * n + j + 1] = sum01 * descale01 * scale; + S[(i + 1) * n + j] = sum10 * descale10 * scale; + S[(i + 1) * n + j + 1] = sum11 * descale11 * scale; + S[(i + 2) * n + j] = sum20 * descale20 * scale; + S[(i + 2) * n + j + 1] = sum21 * descale21 * scale; + S[(i + 3) * n + j] = sum30 * descale30 * scale; + S[(i + 3) * n + j + 1] = sum31 * descale31 * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + + __m512i acc0 = _mm512_setzero_epi32(); + __m512i acc1 = _mm512_setzero_epi32(); + __m512i acc2 = _mm512_setzero_epi32(); + __m512i acc3 = _mm512_setzero_epi32(); + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q0_512 = _mm512_loadu_si512((const __m512i*)(q0 + off)); + __m512i q1_512 = _mm512_loadu_si512((const __m512i*)(q1 + off)); + __m512i q2_512 = _mm512_loadu_si512((const __m512i*)(q2 + off)); + __m512i q3_512 = _mm512_loadu_si512((const __m512i*)(q3 + off)); + + __m512i k_512 = _mm512_loadu_si512((const __m512i*)(kptr + off)); + __m512i k_u8 = _mm512_add_epi8(k_512, _mm512_set1_epi8(128)); + + acc0 = _mm512_dpbusd_epi32(acc0, k_u8, q0_512); + acc1 = _mm512_dpbusd_epi32(acc1, k_u8, q1_512); + acc2 = _mm512_dpbusd_epi32(acc2, k_u8, q2_512); + acc3 = _mm512_dpbusd_epi32(acc3, k_u8, q3_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar0 += q0[k] * kptr[k]; + scalar1 += q1[k] * kptr[k]; + scalar2 += q2[k] * kptr[k]; + scalar3 += q3[k] * kptr[k]; + } + } + + float descale = kscales[j]; + float sum0 = (float)(_mm512_reduce_add_epi32(acc0) - 128 * qsum0_64byte) + (float)scalar0; + float sum1 = (float)(_mm512_reduce_add_epi32(acc1) - 128 * qsum1_64byte) + (float)scalar1; + float sum2 = 
(float)(_mm512_reduce_add_epi32(acc2) - 128 * qsum2_64byte) + (float)scalar2; + float sum3 = (float)(_mm512_reduce_add_epi32(acc3) - 128 * qsum3_64byte) + (float)scalar3; + S[(i + 0) * n + j] = sum0 * qs0 * descale * scale; + S[(i + 1) * n + j] = sum1 * qs1 * descale * scale; + S[(i + 2) * n + j] = sum2 * qs2 * descale * scale; + S[(i + 3) * n + j] = sum3 * qs3 * descale * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_avx512vnni_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} +#endif // __AVX512VNNI__ + +static void qk_int8_gemm_tiled(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + qk_int8_gemm_tiled_avx512vnni(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if __AVX512VNNI__ + qk_int8_gemm_tiled_avx512vnni_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +#else + qk_int8_gemm_tiled_avx512_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +#endif +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + qk_int8_gemm_tiled_avx512vnni(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + qk_int8_gemm_tiled_avxvnniint8(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + qk_int8_gemm_tiled_avxvnni(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { 
+ qk_int8_gemm_tiled_avx2(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if __AVX2__ + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +#elif __SSE2__ + qk_int8_gemm_tiled_sse2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +#else + qk_int8_gemm_tiled_scalar_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +#endif +#endif +} + +// ------------------- Decode PV GEMV Int8 ------------------- + +static inline void decode_pv_gemv_int8_scalar_kernel(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? k_start + 32 : out_d; + for (int k = k_start; k < k_end; k++) + out[k] = 0.f; + + for (int j = 0; j < block_n; j++) + { + float p = s[j]; + float inv_scale = 1.f / vscales[(n_start + j) * num_blocks + vb]; + const signed char* vptr = V + (n_start + j) * out_d + k_start; + for (int k = k_start; k < k_end; k++) + out[k] += p * (float)vptr[k - k_start] * inv_scale; + } + } +} + +#if __SSE2__ +static inline void decode_pv_gemv_int8_sse2_kernel(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + + int k = k_start; + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + V[(n_start + j) * out_d + k + 3], V[(n_start + j) * out_d + k + 2], + V[(n_start + j) * out_d + k + 1], V[(n_start + j) * out_d + k + 0]); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, _mm_add_ps(_mm_loadu_ps(out + k), oval)); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + sum += p_invscale * V[(n_start + j) * out_d + k]; + } + out[k] += sum; + } + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void decode_pv_gemv_int8_avx2_kernel(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + + int k = k_start; + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + (n_start + j) * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, _mm256_add_ps(_mm256_loadu_ps(out + k), oval)); + } + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(V + (n_start + j) * out_d + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < block_n; j++) + sum += s[j] / vscales[(n_start + j) * num_blocks + vb] * V[(n_start + j) * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX__ + +#if __AVX512F__ +static inline void decode_pv_gemv_int8_avx512_kernel(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + int j = 0; + for (; j + 1 < block_n; j += 2) + { + float p0 = s[j]; + float p1 = s[j + 1]; + const signed char* v0 = V + (n_start + j) * out_d; + const signed char* v1 = V + (n_start + j + 1) * out_d; + + int k = 0; + for (; k + 31 < out_d; k += 32) + { + int vb = k / 32; + float p0_invscale = p0 / vscales[(n_start + j) * num_blocks + vb]; + float p1_invscale = p1 / vscales[(n_start + j + 1) * num_blocks + vb]; + + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + + 
__m128i v8_0a = _mm_loadu_si128((const __m128i*)(v0 + k)); + __m128i v8_0b = _mm_loadu_si128((const __m128i*)(v0 + k + 16)); + oval0 = _mm512_fmadd_ps(_mm512_set1_ps(p0_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0a)), oval0); + oval1 = _mm512_fmadd_ps(_mm512_set1_ps(p0_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0b)), oval1); + + __m128i v8_1a = _mm_loadu_si128((const __m128i*)(v1 + k)); + __m128i v8_1b = _mm_loadu_si128((const __m128i*)(v1 + k + 16)); + oval0 = _mm512_fmadd_ps(_mm512_set1_ps(p1_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1a)), oval0); + oval1 = _mm512_fmadd_ps(_mm512_set1_ps(p1_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1b)), oval1); + + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + for (; k + 15 < out_d; k += 16) + { + int vb = k / 32; + float p0_invscale = p0 / vscales[(n_start + j) * num_blocks + vb]; + float p1_invscale = p1 / vscales[(n_start + j + 1) * num_blocks + vb]; + + __m512 oval = _mm512_loadu_ps(out + k); + + __m128i v8_0 = _mm_loadu_si128((const __m128i*)(v0 + k)); + oval = _mm512_fmadd_ps(_mm512_set1_ps(p0_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), oval); + + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(v1 + k)); + oval = _mm512_fmadd_ps(_mm512_set1_ps(p1_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), oval); + + _mm512_storeu_ps(out + k, oval); + } + for (; k < out_d; k++) + { + int vb = k / 32; + float p0_invscale = p0 / vscales[(n_start + j) * num_blocks + vb]; + float p1_invscale = p1 / vscales[(n_start + j + 1) * num_blocks + vb]; + out[k] += p0_invscale * v0[k] + p1_invscale * v1[k]; + } + } + for (; j < block_n; j++) + { + float p = s[j]; + const signed char* v = V + (n_start + j) * out_d; + + int k = 0; + for (; k + 31 < out_d; k += 32) + { + int vb = k / 32; + float p_invscale = p / vscales[(n_start + j) * num_blocks + vb]; + + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + 
+ __m128i v8_0 = _mm_loadu_si128((const __m128i*)(v + k)); + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(v + k + 16)); + oval0 = _mm512_fmadd_ps(_mm512_set1_ps(p_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), oval0); + oval1 = _mm512_fmadd_ps(_mm512_set1_ps(p_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), oval1); + + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + for (; k + 15 < out_d; k += 16) + { + int vb = k / 32; + float p_invscale = p / vscales[(n_start + j) * num_blocks + vb]; + + __m512 oval = _mm512_loadu_ps(out + k); + __m128i v8 = _mm_loadu_si128((const __m128i*)(v + k)); + oval = _mm512_fmadd_ps(_mm512_set1_ps(p_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8)), oval); + _mm512_storeu_ps(out + k, oval); + } + for (; k < out_d; k++) + { + int vb = k / 32; + float p_invscale = p / vscales[(n_start + j) * num_blocks + vb]; + out[k] += p_invscale * v[k]; + } + } +} +#endif // __AVX512F__ + +static void decode_pv_gemv_int8(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ +#if __AVX512F__ + decode_pv_gemv_int8_avx512_kernel(out, s, V, vscales, n_start, block_n, out_d); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + decode_pv_gemv_int8_avx2(out, s, V, vscales, n_start, block_n, out_d); + return; + } +#endif +#if __AVX2__ + decode_pv_gemv_int8_avx2(out, s, V, vscales, n_start, block_n, out_d); +#elif __SSE2__ + decode_pv_gemv_int8_sse2_kernel(out, s, V, vscales, n_start, block_n, out_d); +#else + decode_pv_gemv_int8_scalar_kernel(out, s, V, vscales, n_start, block_n, out_d); +#endif +#endif +} + +// ------------------- Prefill PV Float×Int8 GEMM Row-wise ------------------- + +static inline void pv_float_int8_gemm_row_scalar_kernel(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 
32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? k_start + 32 : out_d; + for (int k = k_start; k < k_end; k++) + out[k] = 0.f; + + for (int j = 0; j < n; j++) + { + float p = p_row[j]; + float inv_scale = 1.f / vscales[j * num_blocks + vb]; + const signed char* vptr = V + j * out_d + k_start; + for (int k = k_start; k < k_end; k++) + out[k] += p * (float)vptr[k - k_start] * inv_scale; + } + } +} + +#if __SSE2__ +static inline void pv_float_int8_gemm_row_sse2_kernel(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? k_start + 32 : out_d; + int k = k_start; + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(V + j * out_d + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < n; j++) + sum += p_row[j] / vscales[j * num_blocks + vb] * V[j * out_d + k]; + out[k] = sum; + } + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void pv_float_int8_gemm_row_avx2_kernel(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + int k = k_start; + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + j * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, oval); + } + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(V + j * out_d + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < n; j++) + sum += p_row[j] / vscales[j * num_blocks + vb] * V[j * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void pv_float_int8_gemm_row_avx512_kernel(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + int k = k_start; + for (; k + 15 < k_end; k += 16) + { + __m512 oval = _mm512_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m512 pvec = _mm512_set1_ps(p_invscale); + __m128i v8 = _mm_loadu_si128((const __m128i*)(V + j * out_d + k)); + __m512i v32 = _mm512_cvtepi8_epi32(v8); + __m512 vfp = _mm512_cvtepi32_ps(v32); + oval = _mm512_fmadd_ps(pvec, vfp, oval); + } + _mm512_storeu_ps(out + k, oval); + } + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + j * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < n; j++) + sum += p_row[j] / vscales[j * num_blocks + vb] * V[j * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX512F__ + +static void pv_float_int8_gemm_row(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ +#if __AVX512F__ + pv_float_int8_gemm_row_avx512_kernel(out, p_row, V, vscales, n, out_d); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + pv_float_int8_gemm_row_avx2(out, p_row, V, vscales, n, out_d); + return; + } +#endif +#if __AVX2__ + pv_float_int8_gemm_row_avx2(out, p_row, V, vscales, n, out_d); +#elif __SSE2__ + pv_float_int8_gemm_row_sse2_kernel(out, p_row, V, vscales, n, out_d); +#else + pv_float_int8_gemm_row_scalar_kernel(out, p_row, V, vscales, n, out_d); +#endif +#endif +} + +// ------------------- PV Float×Int8 FMA Block (online softmax) ------------------- + +static inline void pv_float_int8_fma_block_scalar_kernel(float* out, float 
p_invscale, const signed char* v, int len) +{ + for (int k = 0; k < len; k++) + out[k] += p_invscale * v[k]; +} + +#if __SSE2__ +static inline void pv_float_int8_fma_block_sse2_kernel(float* out, float p_invscale, const signed char* v, int len) +{ + __m128 pvec = _mm_set1_ps(p_invscale); + int k = 0; + for (; k + 3 < len; k += 4) + { + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(v + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + _mm_storeu_ps(out + k, _mm_add_ps(_mm_loadu_ps(out + k), _mm_mul_ps(pvec, vfp))); + } + for (; k < len; k++) + out[k] += p_invscale * v[k]; +} +#endif // __SSE2__ + +#if __AVX2__ +inline void pv_float_int8_fma_block_avx2_kernel(float* out, float p_invscale, const signed char* v, int len) +{ + __m256 pvec = _mm256_set1_ps(p_invscale); + int k = 0; + for (; k + 7 < len; k += 8) + { + __m128i v8 = _mm_loadl_epi64((const __m128i*)(v + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + _mm256_storeu_ps(out + k, _mm256_fmadd_ps(pvec, vfp, _mm256_loadu_ps(out + k))); + } + for (; k + 3 < len; k += 4) + { + __m128 pvec128 = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(v + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + _mm_storeu_ps(out + k, _mm_add_ps(_mm_loadu_ps(out + k), _mm_mul_ps(pvec128, vfp))); + } + for (; k < len; k++) + out[k] += p_invscale * v[k]; +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void pv_float_int8_fma_block_avx512_kernel(float* out, float p_invscale, const signed char* v, int len) +{ + __m512 pvec = _mm512_set1_ps(p_invscale); + int k = 0; + for (; k + 15 < len; k += 16) + { + __m128i v8 = _mm_loadu_si128((const __m128i*)(v + k)); + __m512i v32 = _mm512_cvtepi8_epi32(v8); + __m512 vfp = _mm512_cvtepi32_ps(v32); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vfp, _mm512_loadu_ps(out + k))); + } + for (; k + 7 < len; k += 8) + { + __m256 pvec256 = _mm256_set1_ps(p_invscale); + 
__m128i v8 = _mm_loadl_epi64((const __m128i*)(v + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + _mm256_storeu_ps(out + k, _mm256_fmadd_ps(pvec256, vfp, _mm256_loadu_ps(out + k))); + } + for (; k < len; k++) + out[k] += p_invscale * v[k]; +} +#endif // __AVX512F__ + +static void pv_float_int8_fma_block(float* out, float p_invscale, const signed char* v, int len) +{ +#if __AVX512F__ + pv_float_int8_fma_block_avx512_kernel(out, p_invscale, v, len); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + pv_float_int8_fma_block_avx2(out, p_invscale, v, len); + return; + } +#endif +#if __AVX2__ + pv_float_int8_fma_block_avx2(out, p_invscale, v, len); +#elif __SSE2__ + pv_float_int8_fma_block_sse2_kernel(out, p_invscale, v, len); +#else + pv_float_int8_fma_block_scalar_kernel(out, p_invscale, v, len); +#endif +#endif +} + +#if __AVX2__ +inline void pv_float_int8_gemm_tile_avx2_kernel(float* O, const float* P, + const signed char* V, const float* vscales, + int block_m, int block_n, int out_embed_dim) +{ + const int v_num_blocks = (out_embed_dim + 31) / 32; + int i = 0; + for (; i + 1 < block_m; i += 2) + { + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m256 acc0_0 = _mm256_setzero_ps(); + __m256 acc0_1 = _mm256_setzero_ps(); + __m256 acc0_2 = _mm256_setzero_ps(); + __m256 acc0_3 = _mm256_setzero_ps(); + __m256 acc1_0 = _mm256_setzero_ps(); + __m256 acc1_1 = _mm256_setzero_ps(); + __m256 acc1_2 = _mm256_setzero_ps(); + __m256 acc1_3 = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m256 vscale8 = _mm256_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadl_epi64((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadl_epi64((const __m128i*)(vptr + 8)); + __m128i v8_2 = _mm_loadl_epi64((const __m128i*)(vptr + 16)); + __m128i v8_3 = _mm_loadl_epi64((const __m128i*)(vptr + 24)); + __m256 vval_0 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_0)), vscale8); + __m256 vval_1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_1)), vscale8); + __m256 vval_2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_2)), vscale8); + __m256 vval_3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_3)), vscale8); + __m256 pvec0 = _mm256_set1_ps(p0[j]); + __m256 pvec1 = _mm256_set1_ps(p1[j]); + acc0_0 = _mm256_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm256_fmadd_ps(pvec0, vval_1, acc0_1); + acc0_2 = _mm256_fmadd_ps(pvec0, vval_2, acc0_2); + acc0_3 = _mm256_fmadd_ps(pvec0, vval_3, acc0_3); + acc1_0 = _mm256_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm256_fmadd_ps(pvec1, vval_1, acc1_1); + acc1_2 = _mm256_fmadd_ps(pvec1, vval_2, acc1_2); + acc1_3 = _mm256_fmadd_ps(pvec1, vval_3, acc1_3); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + _mm256_storeu_ps(optr0 + 0, _mm256_add_ps(_mm256_loadu_ps(optr0 + 0), acc0_0)); + _mm256_storeu_ps(optr0 + 8, _mm256_add_ps(_mm256_loadu_ps(optr0 + 8), acc0_1)); + 
_mm256_storeu_ps(optr0 + 16, _mm256_add_ps(_mm256_loadu_ps(optr0 + 16), acc0_2)); + _mm256_storeu_ps(optr0 + 24, _mm256_add_ps(_mm256_loadu_ps(optr0 + 24), acc0_3)); + _mm256_storeu_ps(optr1 + 0, _mm256_add_ps(_mm256_loadu_ps(optr1 + 0), acc1_0)); + _mm256_storeu_ps(optr1 + 8, _mm256_add_ps(_mm256_loadu_ps(optr1 + 8), acc1_1)); + _mm256_storeu_ps(optr1 + 16, _mm256_add_ps(_mm256_loadu_ps(optr1 + 16), acc1_2)); + _mm256_storeu_ps(optr1 + 24, _mm256_add_ps(_mm256_loadu_ps(optr1 + 24), acc1_3)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + } + } + } + } + } + for (; i < block_m; i++) + { + const float* p_row = P + i * block_n; + for (int j = 0; j < block_n; j++) + { + float p = p_row[j]; + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k_start + k] * vscale; + } + } + } + } +} +#endif + +static void pv_float_int8_gemm_tile(float* O, const float* P, + const signed char* V, const float* vscales, + int block_m, int block_n, int out_embed_dim) +{ + const int v_num_blocks = (out_embed_dim + 31) / 32; +#if __AVX512F__ + int i = 0; + for (; i + 3 < block_m; i += 4) + { + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + const float* p2 = P + (i + 2) * block_n; + const float* p3 = P + (i + 3) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m512 acc0_0 = _mm512_setzero_ps(); + __m512 acc0_1 = _mm512_setzero_ps(); + __m512 acc1_0 = _mm512_setzero_ps(); + __m512 acc1_1 = _mm512_setzero_ps(); + __m512 acc2_0 = _mm512_setzero_ps(); + __m512 acc2_1 = _mm512_setzero_ps(); + __m512 acc3_0 = _mm512_setzero_ps(); + __m512 acc3_1 = _mm512_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m512 vscale16 = _mm512_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadu_si128((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(vptr + 16)); + __m512 vval_0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), vscale16); + __m512 vval_1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), vscale16); + __m512 pvec0 = _mm512_set1_ps(p0[j]); + __m512 pvec1 = _mm512_set1_ps(p1[j]); + __m512 pvec2 = _mm512_set1_ps(p2[j]); + __m512 pvec3 = _mm512_set1_ps(p3[j]); + acc0_0 = _mm512_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm512_fmadd_ps(pvec0, vval_1, 
acc0_1); + acc1_0 = _mm512_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm512_fmadd_ps(pvec1, vval_1, acc1_1); + acc2_0 = _mm512_fmadd_ps(pvec2, vval_0, acc2_0); + acc2_1 = _mm512_fmadd_ps(pvec2, vval_1, acc2_1); + acc3_0 = _mm512_fmadd_ps(pvec3, vval_0, acc3_0); + acc3_1 = _mm512_fmadd_ps(pvec3, vval_1, acc3_1); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + float* optr2 = O + (i + 2) * out_embed_dim + k_start; + float* optr3 = O + (i + 3) * out_embed_dim + k_start; + _mm512_storeu_ps(optr0 + 0, _mm512_add_ps(_mm512_loadu_ps(optr0 + 0), acc0_0)); + _mm512_storeu_ps(optr0 + 16, _mm512_add_ps(_mm512_loadu_ps(optr0 + 16), acc0_1)); + _mm512_storeu_ps(optr1 + 0, _mm512_add_ps(_mm512_loadu_ps(optr1 + 0), acc1_0)); + _mm512_storeu_ps(optr1 + 16, _mm512_add_ps(_mm512_loadu_ps(optr1 + 16), acc1_1)); + _mm512_storeu_ps(optr2 + 0, _mm512_add_ps(_mm512_loadu_ps(optr2 + 0), acc2_0)); + _mm512_storeu_ps(optr2 + 16, _mm512_add_ps(_mm512_loadu_ps(optr2 + 16), acc2_1)); + _mm512_storeu_ps(optr3 + 0, _mm512_add_ps(_mm512_loadu_ps(optr3 + 0), acc3_0)); + _mm512_storeu_ps(optr3 + 16, _mm512_add_ps(_mm512_loadu_ps(optr3 + 16), acc3_1)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + O[(i + 2) * out_embed_dim + k_start + k] += p2[j] * vval; + O[(i + 3) * out_embed_dim + k_start + k] += p3[j] * vval; + } + } + } + } + } + for (; i < block_m; i++) + { + const float* p_row = P + i * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m512 vscale16 = _mm512_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadu_si128((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(vptr + 16)); + __m512 vval_0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), vscale16); + __m512 vval_1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), vscale16); + __m512 pvec = _mm512_set1_ps(p_row[j]); + acc0 = _mm512_fmadd_ps(pvec, vval_0, acc0); + acc1 = _mm512_fmadd_ps(pvec, vval_1, acc1); + } + float* optr = O + i * out_embed_dim + k_start; + _mm512_storeu_ps(optr + 0, _mm512_add_ps(_mm512_loadu_ps(optr + 0), acc0)); + _mm512_storeu_ps(optr + 16, _mm512_add_ps(_mm512_loadu_ps(optr + 16), acc1)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float p = p_row[j]; + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k] * vscale; + } + } + } + } + } +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + pv_float_int8_gemm_tile_avx2(O, P, V, vscales, block_m, block_n, out_embed_dim); + return; + } +#endif +#if __AVX2__ + int i = 0; + for (; i + 1 < block_m; i += 2) + { + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m256 acc0_0 = _mm256_setzero_ps(); + __m256 acc0_1 = _mm256_setzero_ps(); + __m256 acc0_2 = _mm256_setzero_ps(); + __m256 acc0_3 = _mm256_setzero_ps(); + __m256 acc1_0 = _mm256_setzero_ps(); + __m256 acc1_1 = _mm256_setzero_ps(); + __m256 acc1_2 = _mm256_setzero_ps(); + __m256 acc1_3 = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m256 vscale8 = _mm256_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadl_epi64((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadl_epi64((const __m128i*)(vptr + 8)); + __m128i v8_2 = _mm_loadl_epi64((const __m128i*)(vptr + 16)); + __m128i v8_3 = _mm_loadl_epi64((const __m128i*)(vptr + 24)); + __m256 vval_0 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_0)), vscale8); + __m256 vval_1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_1)), vscale8); + __m256 vval_2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_2)), vscale8); + __m256 vval_3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_3)), vscale8); + __m256 pvec0 = _mm256_set1_ps(p0[j]); + __m256 pvec1 = _mm256_set1_ps(p1[j]); + acc0_0 = _mm256_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm256_fmadd_ps(pvec0, vval_1, acc0_1); + acc0_2 = _mm256_fmadd_ps(pvec0, vval_2, acc0_2); + acc0_3 = _mm256_fmadd_ps(pvec0, vval_3, acc0_3); + acc1_0 = _mm256_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm256_fmadd_ps(pvec1, vval_1, acc1_1); + acc1_2 = _mm256_fmadd_ps(pvec1, vval_2, acc1_2); + acc1_3 = _mm256_fmadd_ps(pvec1, vval_3, acc1_3); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + _mm256_storeu_ps(optr0 + 0, _mm256_add_ps(_mm256_loadu_ps(optr0 + 0), acc0_0)); + _mm256_storeu_ps(optr0 + 8, _mm256_add_ps(_mm256_loadu_ps(optr0 + 8), acc0_1)); + 
_mm256_storeu_ps(optr0 + 16, _mm256_add_ps(_mm256_loadu_ps(optr0 + 16), acc0_2)); + _mm256_storeu_ps(optr0 + 24, _mm256_add_ps(_mm256_loadu_ps(optr0 + 24), acc0_3)); + _mm256_storeu_ps(optr1 + 0, _mm256_add_ps(_mm256_loadu_ps(optr1 + 0), acc1_0)); + _mm256_storeu_ps(optr1 + 8, _mm256_add_ps(_mm256_loadu_ps(optr1 + 8), acc1_1)); + _mm256_storeu_ps(optr1 + 16, _mm256_add_ps(_mm256_loadu_ps(optr1 + 16), acc1_2)); + _mm256_storeu_ps(optr1 + 24, _mm256_add_ps(_mm256_loadu_ps(optr1 + 24), acc1_3)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + } + } + } + } + } + for (; i < block_m; i++) + { + const float* p_row = P + i * block_n; + for (int j = 0; j < block_n; j++) + { + float p = p_row[j]; + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k_start + k] * vscale; + } + } + } + } +#elif __SSE2__ + for (int j = 0; j < block_n; j++) + { + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + int k = k_start; + for (; k + 3 < k_end; k += 4) + { + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(vptr + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + __m128 vval = _mm_mul_ps(vfp, _mm_set1_ps(vscale)); + for (int ii = 0; ii < block_m; ii++) + { + float p = P[ii * block_n + j]; + __m128 pvec = _mm_set1_ps(p); + float* optr = O + ii * out_embed_dim + k; + _mm_storeu_ps(optr, _mm_add_ps(_mm_loadu_ps(optr), _mm_mul_ps(pvec, vval))); + } + } + for (; k < k_end; k++) + { + float vval = (float)vptr[k] * vscale; + for (int ii = 0; ii < block_m; ii++) + { + O[ii * out_embed_dim + k] += P[ii * block_n + j] * vval; + } + } + } + } +#else + for (int j = 0; j < block_n; j++) + { + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = k_start; k < k_end; k++) + { + float vval = (float)vptr[k] * vscale; + for (int ii = 0; ii < block_m; ii++) + { + O[ii * out_embed_dim + k] += P[ii * block_n + j] * vval; + } + } + } + } +#endif +#endif +} + +#if __SSE2__ && !__SSSE3__ +#undef _mm_cvtepi8_epi16 +#endif +#if __SSE2__ && !__SSE4_1__ +#undef _mm_cvtepi8_epi32 +#endif + +#endif // SDPA_X86_INT8_H diff --git a/src/layer/x86/sdpa_x86_xop.cpp b/src/layer/x86/sdpa_x86_xop.cpp new file mode 100644 index 000000000000..9f0216a47b64 --- /dev/null +++ b/src/layer/x86/sdpa_x86_xop.cpp @@ -0,0 +1,56 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __XOP__ + +void decode_qk_dot_int8_xop(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_xop_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_xop(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum += qk_int8_dot_block_xop_kernel(q_row + off, kptr + off, len); + } + s_row[j] = (float)sum * qscale * kscales[j] * scale; + } +} + +void qk_int8_gemm_tiled_xop(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, 
float scale) +{ + for (int i = 0; i < m; i++) + { + qk_int8_gemm_row_xop(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} + +#endif // __XOP__ + +} // namespace ncnn diff --git a/tests/perf/perf_sdpa_decode.cpp b/tests/perf/perf_sdpa_decode.cpp index c670fa53c671..60f39932e421 100644 --- a/tests/perf/perf_sdpa_decode.cpp +++ b/tests/perf/perf_sdpa_decode.cpp @@ -29,6 +29,16 @@ static void perf_sdpa_decode(int embed_dim, int num_heads, int num_groups, int p perf_layer("SDPA", pd, weights, inputs, 3, "embed=%d heads=%d groups=%d past=%d", embed_dim, num_heads, num_groups, past_seqlen); + + // int8 variant + ncnn::ParamDict pd_int8; + pd_int8.set(5, 0); // attn_mask = 0 + pd_int8.set(6, 0.f); // scale = 0 + pd_int8.set(7, 1); // kv_cache = 1 + pd_int8.set(18, 2); // int8_scale_term + perf_layer_int8("SDPA", pd_int8, weights, inputs, 3, + "embed=%d heads=%d groups=%d past=%d", + embed_dim, num_heads, num_groups, past_seqlen); } int main() @@ -52,33 +62,20 @@ int main() // larger model (e.g., 7B scale) perf_sdpa_decode(4096, 32, 32, 0); - perf_sdpa_decode(4096, 32, 32, 128); - perf_sdpa_decode(4096, 32, 32, 512); - perf_sdpa_decode(4096, 32, 32, 1024); - perf_sdpa_decode(4096, 32, 32, 2048); - perf_sdpa_decode(4096, 32, 32, 4096); - perf_sdpa_decode(4096, 32, 32, 8192); + perf_sdpa_decode(4096, 32, 32, 16); + perf_sdpa_decode(4096, 32, 32, 32); + perf_sdpa_decode(4096, 32, 32, 64); // GQA/MQA configurations // GQA: num_groups < num_heads - perf_sdpa_decode(4096, 32, 4, 128); - perf_sdpa_decode(4096, 32, 4, 512); - perf_sdpa_decode(4096, 32, 4, 1024); - perf_sdpa_decode(4096, 32, 4, 2048); - perf_sdpa_decode(4096, 32, 4, 4096); + perf_sdpa_decode(4096, 32, 4, 16); + perf_sdpa_decode(4096, 32, 4, 32); + perf_sdpa_decode(4096, 32, 4, 64); // MQA: num_groups = 1 - perf_sdpa_decode(4096, 32, 1, 128); - perf_sdpa_decode(4096, 32, 1, 512); - perf_sdpa_decode(4096, 32, 1, 1024); - perf_sdpa_decode(4096, 32, 1, 2048); - perf_sdpa_decode(4096, 32, 1, 4096); - - 
// very large context lengths - perf_sdpa_decode(4096, 32, 32, 16384); - perf_sdpa_decode(4096, 32, 32, 32768); - perf_sdpa_decode(4096, 32, 4, 16384); - perf_sdpa_decode(4096, 32, 4, 32768); + perf_sdpa_decode(4096, 32, 1, 16); + perf_sdpa_decode(4096, 32, 1, 32); + perf_sdpa_decode(4096, 32, 1, 64); return 0; } diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp index 6b5c5b08e306..02ef9cd8c892 100644 --- a/tests/perf/perf_sdpa_prefill.cpp +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -3,9 +3,68 @@ #include "perfutil.h" +#include +#include + +static bool match_env_int(const char* name, int value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + + return atoi(s) == value; +} + +static bool has_env_int(const char* name) +{ + const char* s = getenv(name); + return s && s[0]; +} + +static bool match_env_string(const char* name, const char* value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + + return strcmp(s, value) == 0; +} + +static bool should_run_fp_prefill() +{ + return match_env_string("NCNN_PERF_DTYPE", "fp32") + || match_env_string("NCNN_PERF_DTYPE", "fp16ps") + || match_env_string("NCNN_PERF_DTYPE", "fp16psa") + || match_env_string("NCNN_PERF_DTYPE", "bf16ps"); +} + +static bool should_run_int8_prefill() +{ + return match_env_string("NCNN_PERF_DTYPE", "int8"); +} + +static bool should_run_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) +{ + return match_env_int("NCNN_PERF_SDPA_EMBED", embed_dim) + && match_env_int("NCNN_PERF_SDPA_HEADS", num_heads) + && match_env_int("NCNN_PERF_SDPA_GROUPS", num_groups) + && match_env_int("NCNN_PERF_SDPA_SEQLEN", src_seqlen); +} + +static bool should_run_extended_prefill() +{ + return has_env_int("NCNN_PERF_SDPA_EMBED") + && has_env_int("NCNN_PERF_SDPA_HEADS") + && has_env_int("NCNN_PERF_SDPA_GROUPS") + && has_env_int("NCNN_PERF_SDPA_SEQLEN"); +} + // prefill phase: larger src_seqlen, no kv_cache (past_seqlen=0) static void 
perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) { + if (!should_run_prefill(embed_dim, num_heads, num_groups, src_seqlen)) + return; + const int cur_seqlen = src_seqlen; // in prefill, cur_seqlen == src_seqlen const int out_embed_dim = embed_dim; @@ -22,9 +81,25 @@ static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int inputs[1] = PerfMat(embed_dim, cur_seqlen, num_groups); // k inputs[2] = PerfMat(out_embed_dim, cur_seqlen, num_groups); // v - perf_layer("SDPA", pd, weights, inputs, 1, - "embed=%d heads=%d groups=%d seqlen=%d", - embed_dim, num_heads, num_groups, src_seqlen); + if (should_run_fp_prefill()) + { + perf_layer("SDPA", pd, weights, inputs, 1, + "embed=%d heads=%d groups=%d seqlen=%d", + embed_dim, num_heads, num_groups, src_seqlen); + } + + // int8 variant + ncnn::ParamDict pd_int8; + pd_int8.set(5, 0); // attn_mask = 0 + pd_int8.set(6, 0.f); // scale = 0 + pd_int8.set(7, 0); // kv_cache = 0 + pd_int8.set(18, 2); // int8_scale_term + if (should_run_int8_prefill()) + { + perf_layer_int8("SDPA", pd_int8, weights, inputs, 1, + "embed=%d heads=%d groups=%d seqlen=%d", + embed_dim, num_heads, num_groups, src_seqlen); + } } int main() @@ -53,37 +128,38 @@ int main() perf_sdpa_prefill(4096, 32, 32, 16); perf_sdpa_prefill(4096, 32, 32, 32); perf_sdpa_prefill(4096, 32, 32, 64); - perf_sdpa_prefill(4096, 32, 32, 128); - perf_sdpa_prefill(4096, 32, 32, 256); - perf_sdpa_prefill(4096, 32, 32, 512); - perf_sdpa_prefill(4096, 32, 32, 1024); - perf_sdpa_prefill(4096, 32, 32, 2048); - perf_sdpa_prefill(4096, 32, 32, 4096); // GQA/MQA configurations // GQA: num_groups < num_heads - perf_sdpa_prefill(4096, 32, 4, 128); - perf_sdpa_prefill(4096, 32, 4, 256); - perf_sdpa_prefill(4096, 32, 4, 512); - perf_sdpa_prefill(4096, 32, 4, 1024); - perf_sdpa_prefill(4096, 32, 4, 2048); - perf_sdpa_prefill(4096, 32, 4, 4096); + perf_sdpa_prefill(4096, 32, 16, 16); + perf_sdpa_prefill(4096, 32, 16, 32); + 
perf_sdpa_prefill(4096, 32, 16, 64); + perf_sdpa_prefill(4096, 32, 8, 16); + perf_sdpa_prefill(4096, 32, 8, 32); + perf_sdpa_prefill(4096, 32, 8, 64); + perf_sdpa_prefill(4096, 32, 4, 16); + perf_sdpa_prefill(4096, 32, 4, 32); + perf_sdpa_prefill(4096, 32, 4, 64); // MQA: num_groups = 1 - perf_sdpa_prefill(4096, 32, 1, 128); - perf_sdpa_prefill(4096, 32, 1, 256); - perf_sdpa_prefill(4096, 32, 1, 512); - perf_sdpa_prefill(4096, 32, 1, 1024); - perf_sdpa_prefill(4096, 32, 1, 2048); - perf_sdpa_prefill(4096, 32, 1, 4096); - - // very long sequences - perf_sdpa_prefill(4096, 32, 32, 8192); - perf_sdpa_prefill(4096, 32, 32, 16384); - perf_sdpa_prefill(4096, 32, 32, 32768); - perf_sdpa_prefill(4096, 32, 4, 8192); - perf_sdpa_prefill(4096, 32, 4, 16384); - perf_sdpa_prefill(4096, 32, 4, 32768); + perf_sdpa_prefill(4096, 32, 1, 16); + perf_sdpa_prefill(4096, 32, 1, 32); + perf_sdpa_prefill(4096, 32, 1, 64); + + // Longer large-model cases are opt-in through NCNN_PERF_SDPA_* filters. + if (should_run_extended_prefill()) + { + perf_sdpa_prefill(4096, 32, 32, 128); + perf_sdpa_prefill(4096, 32, 32, 256); + perf_sdpa_prefill(4096, 32, 16, 128); + perf_sdpa_prefill(4096, 32, 16, 256); + perf_sdpa_prefill(4096, 32, 8, 128); + perf_sdpa_prefill(4096, 32, 8, 256); + perf_sdpa_prefill(4096, 32, 4, 128); + perf_sdpa_prefill(4096, 32, 4, 256); + perf_sdpa_prefill(4096, 32, 1, 128); + perf_sdpa_prefill(4096, 32, 1, 256); + } return 0; } diff --git a/tests/perf/perfutil.cpp b/tests/perf/perfutil.cpp index 17bc0e62326d..61a58a441533 100644 --- a/tests/perf/perfutil.cpp +++ b/tests/perf/perfutil.cpp @@ -11,8 +11,13 @@ #include #include #include +#include #include +#if defined(__linux__) +#include +#endif + #if NCNN_VULKAN #include "command.h" #include "gpu.h" @@ -25,6 +30,60 @@ #define PERF_RUN_COUNT 20 #define PERF_TARGET_MIN_MS 5.0 +static int perf_env_int(const char* name, int default_value, int min_value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return default_value; + 
+ int v = atoi(s); + return v < min_value ? min_value : v; +} + +static double perf_env_double(const char* name, double default_value, double min_value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return default_value; + + double v = atof(s); + return v < min_value ? min_value : v; +} + +static void setup_perf_cpu_affinity() +{ + static bool initialized = false; + if (initialized) + return; + initialized = true; + +#if defined(__linux__) + const char* s = getenv("NCNN_PERF_CPU_AFFINITY"); + if (!s || !s[0]) + return; + + int cpu = atoi(s); + if (cpu < 0) + return; + if (cpu >= CPU_SETSIZE) + { + fprintf(stderr, "NCNN_PERF_CPU_AFFINITY=%d exceeds CPU_SETSIZE=%d\n", cpu, CPU_SETSIZE); + return; + } + + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(cpu, &cpuset); + + if (sched_setaffinity(0, sizeof(cpuset), &cpuset) != 0) + { + fprintf(stderr, "NCNN_PERF_CPU_AFFINITY=%d sched_setaffinity failed\n", cpu); + } +#else + (void)getenv("NCNN_PERF_CPU_AFFINITY"); +#endif +} + // benchmark result for a single test case struct PerfResult { @@ -267,6 +326,12 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, int forced_inner_loops, PerfResult& result) { + setup_perf_cpu_affinity(); + + const int warmup_count = perf_env_int("NCNN_PERF_WARMUP_COUNT", PERF_WARMUP_COUNT, 1); + const int run_count = perf_env_int("NCNN_PERF_RUN_COUNT", PERF_RUN_COUNT, 1); + const double target_min_ms = perf_env_double("NCNN_PERF_TARGET_MIN_MS", PERF_TARGET_MIN_MS, 0.0); + ncnn::Layer* op = ncnn::create_layer_cpu(layer_type); if (!op) { @@ -290,7 +355,7 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, // warmup and calibrate inner loop count from warmup min time double warmup_min_ms = DBL_MAX; - for (int i = 0; i < PERF_WARMUP_COUNT; i++) + for (int i = 0; i < warmup_count; i++) { double t0 = ncnn::get_current_time(); int ret = run_layer_forward_cpu(op, converted, top_blob_count, opt); @@ -314,19 +379,19 @@ static int 
perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, { // calibrate inner loop count to power of 10, so total time >= PERF_TARGET_MIN_MS inner_loops = 1; - if (warmup_min_ms > 0 && warmup_min_ms < PERF_TARGET_MIN_MS) + if (warmup_min_ms > 0 && warmup_min_ms < target_min_ms) { - while (inner_loops * warmup_min_ms < PERF_TARGET_MIN_MS) + while (inner_loops * warmup_min_ms < target_min_ms) inner_loops *= 10; } } - double* times = new double[PERF_RUN_COUNT]; + double* times = new double[run_count]; double time_sum = 0; double time_min_val = DBL_MAX; double time_max_val = -DBL_MAX; - for (int i = 0; i < PERF_RUN_COUNT; i++) + for (int i = 0; i < run_count; i++) { double start = ncnn::get_current_time(); @@ -344,16 +409,16 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, if (t > time_max_val) time_max_val = t; } - sort_doubles(times, PERF_RUN_COUNT); + sort_doubles(times, run_count); double time_median_val; - if (PERF_RUN_COUNT % 2 == 0) - time_median_val = (times[PERF_RUN_COUNT / 2 - 1] + times[PERF_RUN_COUNT / 2]) / 2.0; + if (run_count % 2 == 0) + time_median_val = (times[run_count / 2 - 1] + times[run_count / 2]) / 2.0; else - time_median_val = times[PERF_RUN_COUNT / 2]; + time_median_val = times[run_count / 2]; result.time_min = time_min_val; result.time_max = time_max_val; - result.time_avg = time_sum / PERF_RUN_COUNT; + result.time_avg = time_sum / run_count; result.time_median = time_median_val; result.loop_count = inner_loops; @@ -409,6 +474,11 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, int forced_inner_loops, PerfResult& result) { + const int warmup_count = perf_env_int("NCNN_PERF_GPU_WARMUP_COUNT", PERF_GPU_WARMUP_COUNT, 1); + const int warmup_batch = perf_env_int("NCNN_PERF_GPU_WARMUP_BATCH", PERF_GPU_WARMUP_BATCH, 1); + const int run_count = perf_env_int("NCNN_PERF_RUN_COUNT", PERF_RUN_COUNT, 1); + const double target_min_ms = perf_env_double("NCNN_PERF_TARGET_MIN_MS", 
PERF_TARGET_MIN_MS, 0.0); + ncnn::Layer* op = ncnn::create_layer_vulkan(layer_type); if (!op) { @@ -515,17 +585,17 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, // warmup and calibrate inner loop count from warmup min time // batch multiple forwards per submit to amortize submit_and_wait overhead double warmup_min_ms = DBL_MAX; - for (int i = 0; i < PERF_GPU_WARMUP_COUNT; i++) + for (int i = 0; i < warmup_count; i++) { ncnn::VkCompute cmd(vkdev); - for (int b = 0; b < PERF_GPU_WARMUP_BATCH; b++) + for (int b = 0; b < warmup_batch; b++) { run_layer_forward_gpu(op, vk_inputs, top_blob_count, cmd, opt); } double t0 = ncnn::get_current_time(); cmd.submit_and_wait(); double t1 = ncnn::get_current_time(); - double t = (t1 - t0) / PERF_GPU_WARMUP_BATCH; + double t = (t1 - t0) / warmup_batch; if (t < warmup_min_ms) warmup_min_ms = t; } @@ -538,20 +608,20 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, { // calibrate inner loop count to power of 10, so total time >= PERF_TARGET_MIN_MS inner_loops = 1; - if (warmup_min_ms > 0 && warmup_min_ms < PERF_TARGET_MIN_MS) + if (warmup_min_ms > 0 && warmup_min_ms < target_min_ms) { - while (inner_loops * warmup_min_ms < PERF_TARGET_MIN_MS) + while (inner_loops * warmup_min_ms < target_min_ms) inner_loops *= 10; } } // record inner_loops forwards into one command buffer, single submit // this measures pure GPU kernel time, excluding per-launch overhead - double* times = new double[PERF_RUN_COUNT]; + double* times = new double[run_count]; double time_sum = 0; double time_min_val = DBL_MAX; double time_max_val = -DBL_MAX; - for (int i = 0; i < PERF_RUN_COUNT; i++) + for (int i = 0; i < run_count; i++) { ncnn::VkCompute cmd(vkdev); for (int k = 0; k < inner_loops; k++) @@ -571,16 +641,16 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, if (t > time_max_val) time_max_val = t; } - sort_doubles(times, PERF_RUN_COUNT); + sort_doubles(times, 
run_count); double time_median_val; - if (PERF_RUN_COUNT % 2 == 0) - time_median_val = (times[PERF_RUN_COUNT / 2 - 1] + times[PERF_RUN_COUNT / 2]) / 2.0; + if (run_count % 2 == 0) + time_median_val = (times[run_count / 2 - 1] + times[run_count / 2]) / 2.0; else - time_median_val = times[PERF_RUN_COUNT / 2]; + time_median_val = times[run_count / 2]; result.time_min = time_min_val; result.time_max = time_max_val; - result.time_avg = time_sum / PERF_RUN_COUNT; + result.time_avg = time_sum / run_count; result.time_median = time_median_val; result.loop_count = inner_loops; @@ -743,6 +813,15 @@ static const PrecisionConfig s_configs[] = { }; static const int s_num_configs = sizeof(s_configs) / sizeof(s_configs[0]); +static bool should_run_dtype(const char* label) +{ + const char* s = getenv("NCNN_PERF_DTYPE"); + if (!s || !s[0]) + return true; + + return strcmp(s, label) == 0; +} + static void perf_layer_impl(const char* layer_type, const ncnn::ParamDict& pd, const std::vector& weights, const std::vector& inputs, @@ -754,6 +833,9 @@ static void perf_layer_impl(const char* layer_type, const ncnn::ParamDict& pd, int cpu_inner_loops = 0; for (int i = 0; i < s_num_configs; i++) { + if (!should_run_dtype(s_configs[i].label)) + continue; + ncnn::Option opt = make_perf_option(s_configs[i].fp16_ps, s_configs[i].fp16_arith, s_configs[i].bf16); PerfResult result; @@ -792,6 +874,9 @@ static void perf_layer_impl(const char* layer_type, const ncnn::ParamDict& pd, int gpu_inner_loops = 0; for (int i = 0; i < s_num_configs; i++) { + if (!should_run_dtype(s_configs[i].label)) + continue; + ncnn::Option opt = make_perf_option(s_configs[i].fp16_ps, s_configs[i].fp16_arith, s_configs[i].bf16); PerfResult result; @@ -843,3 +928,38 @@ void perf_layer(const char* layer_type, const ncnn::ParamDict& pd, perf_layer_impl(layer_type, pd, weights, inputs, 1, tag); } + +void perf_layer_int8(const char* layer_type, const ncnn::ParamDict& pd, + const std::vector& weights, + const std::vector& inputs, 
int top_blob_count, + const char* param_fmt, ...) +{ + char tag[256]; + va_list args; + va_start(args, param_fmt); + build_tag(tag, sizeof(tag), layer_type, inputs, param_fmt, args); + va_end(args); + + ncnn::Option opt; + opt.lightmode = true; + opt.num_threads = 1; + opt.use_packing_layout = true; + opt.use_fp16_packed = false; + opt.use_fp16_storage = false; + opt.use_fp16_arithmetic = false; + opt.use_bf16_packed = false; + opt.use_bf16_storage = false; + opt.use_vulkan_compute = false; + opt.use_winograd_convolution = true; + opt.use_sgemm_convolution = true; + opt.use_int8_inference = true; + + PerfResult result; + int ret = should_run_dtype("int8") ? perf_layer_cpu(layer_type, pd, weights, inputs, top_blob_count, opt, 0, result) : -1; + if (ret == 0) + { + char full_tag[512]; + snprintf(full_tag, sizeof(full_tag), "%s int8", tag); + print_perf_result(full_tag, result); + } +} diff --git a/tests/perf/perfutil.h b/tests/perf/perfutil.h index 06b31eeef508..e4ed236de19a 100644 --- a/tests/perf/perfutil.h +++ b/tests/perf/perfutil.h @@ -28,4 +28,10 @@ void perf_layer(const char* layer_type, const ncnn::ParamDict& pd, const std::vector& inputs, int top_blob_count, const char* param_fmt, ...); +// int8-only benchmark (does not test fp16/bf16 variants) +void perf_layer_int8(const char* layer_type, const ncnn::ParamDict& pd, + const std::vector& weights, + const std::vector& inputs, int top_blob_count, + const char* param_fmt, ...); + #endif // PERFUTIL_H diff --git a/tests/test_sdpa.cpp b/tests/test_sdpa.cpp index 3ebc8183ed93..8ff4cc900a9a 100644 --- a/tests/test_sdpa.cpp +++ b/tests/test_sdpa.cpp @@ -3,6 +3,8 @@ #include "testutil.h" +#include + static int test_sdpa(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0.f) { const int src_seqlen = q.h; @@ -47,7 +49,24 @@ static int test_sdpa_0() || test_sdpa(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 1.f) || test_sdpa(RandomMat(12, 127, 4), 
RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f) || test_sdpa(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f) - || test_sdpa(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f); + || test_sdpa(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f) + || test_sdpa(RandomMat(32, 1, 8), RandomMat(32, 66, 8), RandomMat(20, 66, 8), 0) + || test_sdpa(RandomMat(26, 1, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1) + || test_sdpa(RandomMat(64, 1, 12), RandomMat(64, 128, 2), RandomMat(64, 128, 2), 0) + || test_sdpa(RandomMat(64, 1, 12), RandomMat(64, 127, 2), RandomMat(48, 127, 2), 1) + || test_sdpa(RandomMat(44, 1, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 1.f) + || test_sdpa(RandomMat(12, 1, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f); +} + +static int test_sdpa_large_dim() +{ + if (!getenv("NCNN_TEST_SDPA_LARGE_DIM")) + return 0; + + return 0 + || test_sdpa(RandomMat(4096, 16, 32), RandomMat(4096, 16, 1), RandomMat(4096, 16, 1), 0, 1.f / 64.f) + || test_sdpa(RandomMat(4096, 16, 32), RandomMat(4096, 16, 4), RandomMat(4096, 16, 4), 0, 1.f / 64.f) + || test_sdpa(RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), 0, 1.f / 64.f); } #if NCNN_INT8 @@ -73,7 +92,7 @@ static int test_sdpa_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Ma as.push_back(RandomMat(dst_seqlen, src_seqlen)); } - float epsilon = 0.01; + float epsilon = 0.05; int ret = test_layer("SDPA", pd, weights, as, 1, epsilon); if (ret != 0) @@ -98,6 +117,17 @@ static int test_sdpa_1() || test_sdpa_int8(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f) || test_sdpa_int8(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f); } + +static int test_sdpa_int8_large_dim() +{ + if (!getenv("NCNN_TEST_SDPA_LARGE_DIM")) + return 0; + + return 0 + || test_sdpa_int8(RandomMat(4096, 16, 32), RandomMat(4096, 16, 1), RandomMat(4096, 16, 
1), 0, 1.f / 64.f) + || test_sdpa_int8(RandomMat(4096, 16, 32), RandomMat(4096, 16, 4), RandomMat(4096, 16, 4), 0, 1.f / 64.f) + || test_sdpa_int8(RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), 0, 1.f / 64.f); +} #endif int main() @@ -105,8 +135,8 @@ int main() SRAND(7767517); #if NCNN_INT8 - return test_sdpa_0() || test_sdpa_1(); + return test_sdpa_0() || test_sdpa_1() || test_sdpa_large_dim() || test_sdpa_int8_large_dim(); #else - return test_sdpa_0(); + return test_sdpa_0() || test_sdpa_large_dim(); #endif } diff --git a/tests/test_sdpa_kvcache.cpp b/tests/test_sdpa_kvcache.cpp index 1fc84f9c72b3..0ad6da25eb7d 100644 --- a/tests/test_sdpa_kvcache.cpp +++ b/tests/test_sdpa_kvcache.cpp @@ -53,7 +53,13 @@ static int test_sdpa_0() || test_sdpa_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 0) || test_sdpa_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 0) || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 3) - || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5); + || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5) + || test_sdpa_kvcache(RandomMat(32, 1, 8), RandomMat(32, 1, 8), RandomMat(20, 1, 8), 0, 11) + || test_sdpa_kvcache(RandomMat(26, 1, 8), RandomMat(26, 1, 8), RandomMat(18, 1, 8), 1, 11) + || test_sdpa_kvcache(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(64, 1, 2), 0, 1) + || test_sdpa_kvcache(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(48, 1, 2), 1, 1) + || test_sdpa_kvcache(RandomMat(44, 1, 4), RandomMat(44, 1, 4), RandomMat(55, 1, 4), 0, 0) + || test_sdpa_kvcache(RandomMat(12, 1, 4), RandomMat(12, 1, 4), RandomMat(55, 1, 4), 1, 0); } #if NCNN_INT8 @@ -85,7 +91,7 @@ static int test_sdpa_int8_kvcache(const ncnn::Mat& q, const ncnn::Mat& k, const as.push_back(RandomMat(embed_dim, past_seqlen, k.c)); 
as.push_back(RandomMat(out_embed_dim, past_seqlen, v.c)); - float epsilon = 0.01; + float epsilon = 0.05; int ret = test_layer("SDPA", pd, weights, as, 3, epsilon); if (ret != 0)