From 622928c6aea123a7d90e3bc2f1d026885ce36433 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Fri, 24 Apr 2026 17:08:57 +0800 Subject: [PATCH 01/53] basic fa2 --- src/layer/x86/sdpa_x86.cpp | 548 +++++++++++++++++++------------ src/layer/x86/sdpa_x86.h | 6 - tests/perf/perf_sdpa_decode.cpp | 31 +- tests/perf/perf_sdpa_prefill.cpp | 32 +- tests/test_sdpa.cpp | 2 +- tests/test_sdpa_kvcache.cpp | 2 +- 6 files changed, 360 insertions(+), 261 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index a319913c3f9e..e530764dd50e 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -5,6 +5,11 @@ #include "layer_type.h" +#include "cpu.h" + +#include +#include + namespace ncnn { SDPA_x86::SDPA_x86() @@ -12,124 +17,54 @@ SDPA_x86::SDPA_x86() #if NCNN_BF16 support_bf16_storage = true; #endif - - qk_gemm = 0; - qkv_gemm = 0; - qk_softmax = 0; } -int SDPA_x86::create_pipeline(const Option& _opt) +int SDPA_x86::create_pipeline(const Option& /*_opt*/) { - Option opt = _opt; if (int8_scale_term) { - opt.use_packing_layout = false; // TODO enable packing support_bf16_storage = false; } - { - qk_softmax = ncnn::create_layer_cpu(ncnn::LayerType::Softmax); - ncnn::ParamDict pd; - pd.set(0, -1); // axis - pd.set(1, 1); - qk_softmax->load_param(pd); - qk_softmax->load_model(ModelBinFromMatArray(0)); - qk_softmax->create_pipeline(opt); - } - - // Q * K^T - if (scale != 0.f) - { - qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - - pd.set(0, scale); // alpha - pd.set(1, 1.f / scale); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, attn_mask ? 
3 : -1); // constant_broadcast_type_C (MxN) - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(13, 1); // output_elemtype = fp32 -#if NCNN_INT8 - pd.set(18, int8_scale_term); -#endif - qk_gemm->load_param(pd); - qk_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = opt; - opt1.num_threads = 1; - qk_gemm->create_pipeline(opt1); - } - - // Attn * V - { - qkv_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - pd.set(0, 1.f); // alpha - pd.set(1, 1.f); // beta - pd.set(2, 0); // transA (Attn: Seq x Seq) - pd.set(3, 0); // transB (V: Seq x Embed) => Attn * V - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, 1); // constantC (None) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, -1); // constant_broadcast_type_C - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(13, 1); // output_elemtype = fp32 - pd.set(14, 0); // output_transpose -#if NCNN_INT8 - pd.set(18, int8_scale_term); -#endif - qkv_gemm->load_param(pd); - qkv_gemm->load_model(ModelBinFromMatArray(0)); - Option opt1 = opt; - opt1.num_threads = 1; - qkv_gemm->create_pipeline(opt1); - } - return 0; } -int SDPA_x86::destroy_pipeline(const Option& _opt) +int SDPA_x86::destroy_pipeline(const Option& /*_opt*/) { - Option opt = _opt; - if (int8_scale_term) - { - opt.use_packing_layout = false; // TODO enable packing - } - - if (qk_softmax) - { - qk_softmax->destroy_pipeline(opt); - delete qk_softmax; - qk_softmax = 0; - } + return 0; +} - if (qk_gemm) - { - qk_gemm->destroy_pipeline(opt); - delete qk_gemm; - qk_gemm = 0; - } +#if NCNN_INT8 +static inline signed char float2int8(float v) +{ + int int32 = static_cast(round(v)); + if (int32 > 127) return 127; + if (int32 < -127) return -127; + return (signed char)int32; +} - if (qkv_gemm) +static void dynamic_quantize_blockwise(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = 
(width + block_size - 1) / block_size; + for (int b = 0; b < num_blocks; b++) { - qkv_gemm->destroy_pipeline(opt); - delete qkv_gemm; - qkv_gemm = 0; + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + for (int i = start; i < end; i++) + { + absmax = std::max(absmax, (float)fabs(src[i])); + } + float scale = absmax == 0.f ? 1.f : 127.f / absmax; + scales[b] = scale; + for (int i = start; i < end; i++) + { + dst[i] = float2int8(src[i] * scale); + } } - - return 0; } +#endif // NCNN_INT8 int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { @@ -204,138 +139,341 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } const int num_heads_per_group = num_heads / num_group; + const float _scale = scale == 0.f ? 1.f / sqrtf((float)embed_dim) : scale; - Mat qk_cross(dst_seqlen, src_seqlen, num_heads, 4u, opt.workspace_allocator); - if (qk_cross.empty()) - return -100; + const int BLOCK_M = 64; + const int BLOCK_N = 64; - std::vector retqks(num_heads); + Mat& top_blob = top_blobs[0]; + top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; - // Dynamic Scale Calculation and Beta Correction - Layer* _qk_gemm = qk_gemm; - if (scale == 0.f) - { - float _scale = 1.f / sqrt(embed_dim); - - _qk_gemm = ncnn::create_layer_cpu(ncnn::LayerType::Gemm); - ncnn::ParamDict pd; - - pd.set(0, _scale); // alpha - pd.set(1, 1.f / _scale); // beta - pd.set(2, 0); // transA (Q: Seq x Embed) - pd.set(3, 1); // transB (K: Seq x Embed -> K^T: Embed x Seq) => Q * K^T - pd.set(4, 0); // constantA - pd.set(5, 0); // constantB - pd.set(6, attn_mask ? 0 : 1); // constantC (if mask exists, use it) - pd.set(7, 0); // M - pd.set(8, 0); // N - pd.set(9, 0); // K - pd.set(10, attn_mask ? 
3 : -1); // constant_broadcast_type_C (MxN) - pd.set(11, 0); // output_N1M - pd.set(12, 1); // output_elempack - pd.set(13, 1); // output_elemtype = fp32 #if NCNN_INT8 - pd.set(18, int8_scale_term); -#endif - _qk_gemm->load_param(pd); - _qk_gemm->load_model(ModelBinFromMatArray(0)); + if (int8_scale_term) + { + const int qk_num_blocks = (embed_dim + 31) / 32; + const int v_num_blocks = (out_embed_dim + 31) / 32; - Option opt1 = opt; - opt1.num_threads = 1; - _qk_gemm->create_pipeline(opt1); - } + Mat key_int8(embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); + Mat key_scales(qk_num_blocks, dst_seqlen, num_group, 4u, opt.blob_allocator); + Mat value_int8(out_embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); + Mat value_scales(v_num_blocks, dst_seqlen, num_group, 4u, opt.blob_allocator); - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < num_heads; i++) - { - // 1. Q * K^T - std::vector qk_bottom_blobs; - qk_bottom_blobs.push_back(query.channel(i)); // Q: [Seq, Embed] - qk_bottom_blobs.push_back(key.channel(i / num_heads_per_group)); // K: [DstSeq, Embed] + if (key_int8.empty() || key_scales.empty() || value_int8.empty() || value_scales.empty()) + return -100; - if (attn_mask) + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) { - // Ensure mask is 2D for Gemm auto-broadcast detection - Mat maskm = attn_mask_blob; - if (maskm.dims == 3) + const Mat key_head = key.channel(g); + Mat key_int8_head = key_int8.channel(g); + Mat key_scales_head = key_scales.channel(g); + for (int j = 0; j < dst_seqlen; j++) { - // If c > 1, pick i-th head mask. If c == 1, pick 0-th (broadcast) - maskm = maskm.channel(maskm.c > 1 ? 
i : 0); + dynamic_quantize_blockwise(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); + } + + const Mat value_head = value.channel(g); + Mat value_int8_head = value_int8.channel(g); + Mat value_scales_head = value_scales.channel(g); + for (int j = 0; j < dst_seqlen; j++) + { + dynamic_quantize_blockwise(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); } - qk_bottom_blobs.push_back(maskm); } - std::vector qk_top_blobs(1); - qk_top_blobs[0] = qk_cross.channel(i); + Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N, opt.num_threads, 4u, opt.workspace_allocator); + Mat q_int8_tile(embed_dim, BLOCK_M, opt.num_threads, 1u, opt.workspace_allocator); + Mat q_scales_tile(qk_num_blocks, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Option opt1 = opt; - opt1.num_threads = 1; - opt1.blob_allocator = qk_cross.allocator; - retqks[i] = _qk_gemm->forward(qk_bottom_blobs, qk_top_blobs, opt1); - } + if (o_accum.empty() || s_vec.empty() || q_int8_tile.empty() || q_scales_tile.empty()) + return -100; - if (scale == 0.f) - { - Option opt1 = opt; - opt1.num_threads = 1; - _qk_gemm->destroy_pipeline(opt1); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); + + Mat mask_head; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + } - delete _qk_gemm; - _qk_gemm = 0; - } + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); + Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); - for (int i = 0; i < num_heads; i++) - { - if (retqks[i] != 0) - return retqks[i]; - } + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + { + int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; + + for (int i = 0; i < block_m; i++) + { + dynamic_quantize_blockwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + for (int k = 0; k < out_embed_dim; k++) + { + optr[k] = 0.f; + } + } + + float m_vec[64]; + float l_vec[64]; + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + for (int i = 0; i < block_m; i++) + { + const signed char* qptr = q_int8_tile_head.row(i); + const float* qscales = q_scales_tile_head.row(i); + + for (int j = 0; j < block_n; j++) + { + const signed char* kptr = key_int8_head.row(n_start + j); + const float* kscales = key_scales_head.row(n_start + j); + + float sum = 0.f; + for (int b = 0; b < qk_num_blocks; b++) + { + int k_start = b * 32; + int k_end = k_start + 32 < embed_dim ? 
k_start + 32 : embed_dim; + int block_sum = 0; + for (int k = k_start; k < k_end; k++) + { + block_sum += qptr[k] * kptr[k]; + } + sum += (float)block_sum / (qscales[b] * kscales[b]); + } + s_vec_ptr[j] = sum * _scale; + } + + if (attn_mask) + { + const float* mptr = mask_head.row(m_start + i) + n_start; + for (int j = 0; j < block_n; j++) + { + s_vec_ptr[j] += mptr[j]; + } + } + + float m_new = m_vec[i]; + for (int j = 0; j < block_n; j++) + { + m_new = std::max(m_new, s_vec_ptr[j]); + } + + float scale_factor = expf(m_vec[i] - m_new); + float l_new = l_vec[i] * scale_factor; + + float* optr = o_accum_head.row(i); + for (int k = 0; k < out_embed_dim; k++) + { + optr[k] *= scale_factor; + } + + for (int j = 0; j < block_n; j++) + { + float p = expf(s_vec_ptr[j] - m_new); + l_new += p; + + const signed char* vptr = value_int8_head.row(n_start + j); + const float* vscales = value_scales_head.row(n_start + j); + for (int vb = 0; vb < v_num_blocks; vb++) + { + float inv_scale = 1.f / vscales[vb]; + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + for (int k = k_start; k < k_end; k++) + { + optr[k] += p * (float)vptr[k] * inv_scale; + } + } + } + + m_vec[i] = m_new; + l_vec[i] = l_new; + } + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + for (int k = 0; k < out_embed_dim; k++) + { + outptr[k] = optr[k] * inv_l; + } + } + } + } - // 2. 
Softmax - int retqk = qk_softmax->forward_inplace(qk_cross, opt); - if (retqk != 0) - return retqk; + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } - Mat value_fp32 = value; -#if NCNN_BF16 - if (opt.use_bf16_storage && value.elembits() == 16) - { - // qkv_gemm need fp32 inputs - cast_bfloat16_to_float32(value, value_fp32, opt); - if (value_fp32.empty()) - return -100; + return 0; } -#endif +#endif // NCNN_INT8 - Mat& top_blob = top_blobs[0]; - top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); - if (top_blob.empty()) - return -100; + Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N, opt.num_threads, 4u, opt.workspace_allocator); - // 3. Attn * V - std::vector retqkvs(num_heads); + if (o_accum.empty() || s_vec.empty()) + return -100; #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < num_heads; i++) + for (int q = 0; q < num_heads; q++) { - std::vector qkv_bottom_blobs(2); - qkv_bottom_blobs[0] = qk_cross.channel(i); // Attn: [DstSeq, Seq] - qkv_bottom_blobs[1] = value_fp32.channel(i / num_heads_per_group); // V: [DstSeq, OutEmbed] + const Mat query_head = query.channel(q); + const Mat key_head = key.channel(q / num_heads_per_group); + const Mat value_head = value.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); - std::vector qkv_top_blobs(1); - qkv_top_blobs[0] = top_blob.channel(i); // Output + Mat mask_head; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + } - Option opt1 = opt; - opt1.num_threads = 1; - retqkvs[i] = qkv_gemm->forward(qkv_bottom_blobs, qkv_top_blobs, opt1); - } + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); - for (int i = 0; i < num_heads; i++) - { - if (retqkvs[i] != 0) - return retqkvs[i]; - } + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + { + int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; - value_fp32.release(); + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + for (int k = 0; k < out_embed_dim; k++) + { + optr[k] = 0.f; + } + } + + float m_vec[64]; + float l_vec[64]; + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + for (int i = 0; i < block_m; i++) + { + const float* qptr = query_head.row(m_start + i); + + for (int j = 0; j < block_n; j++) + { + const float* kptr = key_head.row(n_start + j); + float sum = 0.f; + for (int k = 0; k < embed_dim; k++) + { + sum += qptr[k] * kptr[k]; + } + s_vec_ptr[j] = sum * _scale; + } + + if (attn_mask) + { + const float* mptr = mask_head.row(m_start + i) + n_start; + for (int j = 0; j < block_n; j++) + { + s_vec_ptr[j] += mptr[j]; + } + } + + float m_new = m_vec[i]; + for (int j = 0; j < block_n; j++) + { + m_new = std::max(m_new, s_vec_ptr[j]); + } + + float scale_factor = expf(m_vec[i] - m_new); + float l_new = l_vec[i] * scale_factor; + + float* optr = o_accum_head.row(i); + for (int k = 0; k < out_embed_dim; k++) + { + optr[k] *= scale_factor; + } + + for (int j = 0; j < block_n; j++) + { + float p = expf(s_vec_ptr[j] - m_new); + l_new += p; + + const float* vptr = value_head.row(n_start + j); + for (int k = 0; k < out_embed_dim; k++) + { + optr[k] += p * vptr[k]; + } + } + + m_vec[i] = m_new; + l_vec[i] = l_new; + } + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + for (int k = 0; k < out_embed_dim; k++) + { + outptr[k] = optr[k] * inv_l; + } + } + } + } if (kv_cache) { diff --git a/src/layer/x86/sdpa_x86.h b/src/layer/x86/sdpa_x86.h index 1b8f28e11f66..11f329dc0175 100644 --- a/src/layer/x86/sdpa_x86.h +++ b/src/layer/x86/sdpa_x86.h @@ -17,12 +17,6 @@ class SDPA_x86 : public SDPA virtual int destroy_pipeline(const Option& opt); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; - -public: - Layer* qk_gemm; - Layer* qkv_gemm; - - Layer* qk_softmax; }; } // namespace ncnn diff --git a/tests/perf/perf_sdpa_decode.cpp b/tests/perf/perf_sdpa_decode.cpp index c670fa53c671..7910c7ef0179 100644 --- 
a/tests/perf/perf_sdpa_decode.cpp +++ b/tests/perf/perf_sdpa_decode.cpp @@ -52,33 +52,20 @@ int main() // larger model (e.g., 7B scale) perf_sdpa_decode(4096, 32, 32, 0); - perf_sdpa_decode(4096, 32, 32, 128); - perf_sdpa_decode(4096, 32, 32, 512); - perf_sdpa_decode(4096, 32, 32, 1024); - perf_sdpa_decode(4096, 32, 32, 2048); - perf_sdpa_decode(4096, 32, 32, 4096); - perf_sdpa_decode(4096, 32, 32, 8192); + perf_sdpa_decode(4096, 32, 32, 16); + perf_sdpa_decode(4096, 32, 32, 32); + perf_sdpa_decode(4096, 32, 32, 64); // GQA/MQA configurations // GQA: num_groups < num_heads - perf_sdpa_decode(4096, 32, 4, 128); - perf_sdpa_decode(4096, 32, 4, 512); - perf_sdpa_decode(4096, 32, 4, 1024); - perf_sdpa_decode(4096, 32, 4, 2048); - perf_sdpa_decode(4096, 32, 4, 4096); + perf_sdpa_decode(4096, 32, 4, 16); + perf_sdpa_decode(4096, 32, 4, 32); + perf_sdpa_decode(4096, 32, 4, 64); // MQA: num_groups = 1 - perf_sdpa_decode(4096, 32, 1, 128); - perf_sdpa_decode(4096, 32, 1, 512); - perf_sdpa_decode(4096, 32, 1, 1024); - perf_sdpa_decode(4096, 32, 1, 2048); - perf_sdpa_decode(4096, 32, 1, 4096); - - // very large context lengths - perf_sdpa_decode(4096, 32, 32, 16384); - perf_sdpa_decode(4096, 32, 32, 32768); - perf_sdpa_decode(4096, 32, 4, 16384); - perf_sdpa_decode(4096, 32, 4, 32768); + perf_sdpa_decode(4096, 32, 1, 16); + perf_sdpa_decode(4096, 32, 1, 32); + perf_sdpa_decode(4096, 32, 1, 64); return 0; } diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp index 6b5c5b08e306..31212c248ab1 100644 --- a/tests/perf/perf_sdpa_prefill.cpp +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -53,37 +53,17 @@ int main() perf_sdpa_prefill(4096, 32, 32, 16); perf_sdpa_prefill(4096, 32, 32, 32); perf_sdpa_prefill(4096, 32, 32, 64); - perf_sdpa_prefill(4096, 32, 32, 128); - perf_sdpa_prefill(4096, 32, 32, 256); - perf_sdpa_prefill(4096, 32, 32, 512); - perf_sdpa_prefill(4096, 32, 32, 1024); - perf_sdpa_prefill(4096, 32, 32, 2048); - perf_sdpa_prefill(4096, 32, 32, 
4096); // GQA/MQA configurations // GQA: num_groups < num_heads - perf_sdpa_prefill(4096, 32, 4, 128); - perf_sdpa_prefill(4096, 32, 4, 256); - perf_sdpa_prefill(4096, 32, 4, 512); - perf_sdpa_prefill(4096, 32, 4, 1024); - perf_sdpa_prefill(4096, 32, 4, 2048); - perf_sdpa_prefill(4096, 32, 4, 4096); + perf_sdpa_prefill(4096, 32, 4, 16); + perf_sdpa_prefill(4096, 32, 4, 32); + perf_sdpa_prefill(4096, 32, 4, 64); // MQA: num_groups = 1 - perf_sdpa_prefill(4096, 32, 1, 128); - perf_sdpa_prefill(4096, 32, 1, 256); - perf_sdpa_prefill(4096, 32, 1, 512); - perf_sdpa_prefill(4096, 32, 1, 1024); - perf_sdpa_prefill(4096, 32, 1, 2048); - perf_sdpa_prefill(4096, 32, 1, 4096); - - // very long sequences - perf_sdpa_prefill(4096, 32, 32, 8192); - perf_sdpa_prefill(4096, 32, 32, 16384); - perf_sdpa_prefill(4096, 32, 32, 32768); - perf_sdpa_prefill(4096, 32, 4, 8192); - perf_sdpa_prefill(4096, 32, 4, 16384); - perf_sdpa_prefill(4096, 32, 4, 32768); + perf_sdpa_prefill(4096, 32, 1, 16); + perf_sdpa_prefill(4096, 32, 1, 32); + perf_sdpa_prefill(4096, 32, 1, 64); return 0; } diff --git a/tests/test_sdpa.cpp b/tests/test_sdpa.cpp index 3ebc8183ed93..04cb92ae2a24 100644 --- a/tests/test_sdpa.cpp +++ b/tests/test_sdpa.cpp @@ -73,7 +73,7 @@ static int test_sdpa_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Ma as.push_back(RandomMat(dst_seqlen, src_seqlen)); } - float epsilon = 0.01; + float epsilon = 0.05; int ret = test_layer("SDPA", pd, weights, as, 1, epsilon); if (ret != 0) diff --git a/tests/test_sdpa_kvcache.cpp b/tests/test_sdpa_kvcache.cpp index 1fc84f9c72b3..3e0a20221ad0 100644 --- a/tests/test_sdpa_kvcache.cpp +++ b/tests/test_sdpa_kvcache.cpp @@ -85,7 +85,7 @@ static int test_sdpa_int8_kvcache(const ncnn::Mat& q, const ncnn::Mat& k, const as.push_back(RandomMat(embed_dim, past_seqlen, k.c)); as.push_back(RandomMat(out_embed_dim, past_seqlen, v.c)); - float epsilon = 0.01; + float epsilon = 0.05; int ret = test_layer("SDPA", pd, weights, as, 3, epsilon); if (ret != 
0) From ad7e388683ed46705fc8bdc91961de46ddddfe0b Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sat, 25 Apr 2026 11:35:20 +0800 Subject: [PATCH 02/53] basic fp32 --- ggml | 1 + src/layer/x86/sdpa_x86.cpp | 1755 ++++++++++++++++++++++++++++++++++-- 2 files changed, 1675 insertions(+), 81 deletions(-) create mode 160000 ggml diff --git a/ggml b/ggml new file mode 160000 index 000000000000..8be60f83ec12 --- /dev/null +++ b/ggml @@ -0,0 +1 @@ +Subproject commit 8be60f83ec124c31f3a427053c29022e3072f8a4 diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index e530764dd50e..f2a64b54d9aa 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -6,9 +6,24 @@ #include "layer_type.h" #include "cpu.h" +#include "x86_usability.h" + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ #include +#include #include +#include namespace ncnn { @@ -66,6 +81,1557 @@ static void dynamic_quantize_blockwise(const float* src, signed char* dst, float } #endif // NCNN_INT8 + +static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * d; + for (int j = 0; j < n; j++) + { + const float* kptr = K + j * d; + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + +static inline void pv_gemm_scalar(float* O, const float* P, const float* V, + int m, int n, int d) +{ + for (int i = 0; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + { + float p = pptr[j]; + const float* vptr = V + j * d; + for (int k = 0; k < d; k++) + optr[k] += p * vptr[k]; + } + } +} + +static inline void vec_scale_scalar(float* x, float s, int n) +{ + for (int i = 0; i < n; i++) x[i] *= 
s; +} + +static inline void vec_zero_scalar(float* x, int n) +{ + for (int i = 0; i < n; i++) x[i] = 0.f; +} + +static inline void softmax_tile_scalar(float* P, const float* S, + float* m_vec, float* l_vec, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + float m_new = m_vec[i]; + for (int j = 0; j < n; j++) m_new = std::max(m_new, sptr[j]); + float scale_factor = expf(m_vec[i] - m_new); + l_vec[i] *= scale_factor; + float l_add = 0.f; + for (int j = 0; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; + } +} + +#if __AVX512F__ + +static void qk_gemm_avx512(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + const float* k2 = K + (j + 2) * d; + const float* k3 = K + (j + 3) * d; + + for (int mi = 0; mi < 4; mi++) + { + const float* qptr = Q + (i + mi) * d; + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + float sum0 = _mm512_comp_reduce_add_ps(acc0); + float sum1 = _mm512_comp_reduce_add_ps(acc1); + float sum2 = _mm512_comp_reduce_add_ps(acc2); + float sum3 = _mm512_comp_reduce_add_ps(acc3); + + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + sum2 += qv * k2[k]; + sum3 += qv * k3[k]; + } + + S[(i + mi) * n + j + 0] = sum0 * scale; + S[(i + mi) * n + j + 1] = 
sum1 * scale; + S[(i + mi) * n + j + 2] = sum2 * scale; + S[(i + mi) * n + j + 3] = sum3 * scale; + } + } + + for (; j < n; j++) + { + for (int mi = 0; mi < 4; mi++) + { + const float* qptr = Q + (i + mi) * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m512 vacc = _mm512_setzero_ps(); + for (; k + 15 < d; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + sum = _mm512_comp_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[(i + mi) * n + j] = sum * scale; + } + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 3 < n; j += 4) + { + const float* qptr = Q + i * d; + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + const float* k2 = K + (j + 2) * d; + const float* k3 = K + (j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + float sum0 = _mm512_comp_reduce_add_ps(acc0); + float sum1 = _mm512_comp_reduce_add_ps(acc1); + float sum2 = _mm512_comp_reduce_add_ps(acc2); + float sum3 = _mm512_comp_reduce_add_ps(acc3); + + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + sum2 += qv * k2[k]; + sum3 += qv * k3[k]; + } + + S[i * n + j + 0] = sum0 * scale; + S[i * n + j + 1] = sum1 * scale; + S[i * n + j + 2] = sum2 * scale; + S[i * n + j + 3] = sum3 * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m512 vacc = _mm512_setzero_ps(); + for (; k + 15 < d; k += 16) + vacc = 
_mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + sum = _mm512_comp_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + + +template +static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m512 acc[4][4]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j < n; j++) + { + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + const float* kptr = K + j * D; + for (int k = 0; k < D; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = 
_mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 3 < n; j += 4) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < D; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + + +template +static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 16; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m512 acc[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + for 
(int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm512_loadu_ps(op[mi] + dd + vi * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + pvec[mi] = _mm512_set1_ps(pptr[mi][j]); + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd + vi * 16); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi][vi] = _mm512_fmadd_ps(pvec[mi], vvec, acc[mi][vi]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(op[mi] + dd + vi * 16, acc[mi][vi]); + } + + for (; dd + 15 < d; dd += 16) + { + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < d; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m512 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm512_loadu_ps(optr + dd + vi * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(optr + dd + vi * 16, acc[vi]); + } + + for (; dd + 15 < d; dd += 16) + { + __m512 acc = _mm512_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), 
_mm512_loadu_ps(V + j * d + dd), acc); + _mm512_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; + } + } +} + + +template +static __attribute__((noinline)) void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 16; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + 127 < D; dd += 128) + { + __m512 acc[M_BLOCK][8]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + dd + 0 * 16); + acc[mi][1] = _mm512_loadu_ps(op[mi] + dd + 1 * 16); + acc[mi][2] = _mm512_loadu_ps(op[mi] + dd + 2 * 16); + acc[mi][3] = _mm512_loadu_ps(op[mi] + dd + 3 * 16); + acc[mi][4] = _mm512_loadu_ps(op[mi] + dd + 4 * 16); + acc[mi][5] = _mm512_loadu_ps(op[mi] + dd + 5 * 16); + acc[mi][6] = _mm512_loadu_ps(op[mi] + dd + 6 * 16); + acc[mi][7] = _mm512_loadu_ps(op[mi] + dd + 7 * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 v0 = _mm512_loadu_ps(V + j * D + dd + 0 * 16); + __m512 v1 = _mm512_loadu_ps(V + j * D + dd + 1 * 16); + __m512 v2 = _mm512_loadu_ps(V + j * D + dd + 2 * 16); + __m512 v3 = _mm512_loadu_ps(V + j * D + dd + 3 * 16); + __m512 v4 = _mm512_loadu_ps(V + j * D + dd + 4 * 16); + __m512 v5 = _mm512_loadu_ps(V + j * D + dd + 5 * 16); + __m512 v6 = _mm512_loadu_ps(V + j * D + dd + 6 * 16); + __m512 v7 = _mm512_loadu_ps(V + j * D + dd + 7 * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + acc[mi][0] = _mm512_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(pvec, v3, acc[mi][3]); + acc[mi][4] = _mm512_fmadd_ps(pvec, v4, 
acc[mi][4]); + acc[mi][5] = _mm512_fmadd_ps(pvec, v5, acc[mi][5]); + acc[mi][6] = _mm512_fmadd_ps(pvec, v6, acc[mi][6]); + acc[mi][7] = _mm512_fmadd_ps(pvec, v7, acc[mi][7]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + _mm512_storeu_ps(op[mi] + dd + 0 * 16, acc[mi][0]); + _mm512_storeu_ps(op[mi] + dd + 1 * 16, acc[mi][1]); + _mm512_storeu_ps(op[mi] + dd + 2 * 16, acc[mi][2]); + _mm512_storeu_ps(op[mi] + dd + 3 * 16, acc[mi][3]); + _mm512_storeu_ps(op[mi] + dd + 4 * 16, acc[mi][4]); + _mm512_storeu_ps(op[mi] + dd + 5 * 16, acc[mi][5]); + _mm512_storeu_ps(op[mi] + dd + 6 * 16, acc[mi][6]); + _mm512_storeu_ps(op[mi] + dd + 7 * 16, acc[mi][7]); + } + } + + for (; dd + 15 < D; dd += 16) + { + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m512 vvec = _mm512_loadu_ps(V + j * D + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < D; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * D + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + 127 < D; dd += 128) + { + __m512 acc0 = _mm512_loadu_ps(optr + dd + 0 * 16); + __m512 acc1 = _mm512_loadu_ps(optr + dd + 1 * 16); + __m512 acc2 = _mm512_loadu_ps(optr + dd + 2 * 16); + __m512 acc3 = _mm512_loadu_ps(optr + dd + 3 * 16); + __m512 acc4 = _mm512_loadu_ps(optr + dd + 4 * 16); + __m512 acc5 = _mm512_loadu_ps(optr + dd + 5 * 16); + __m512 acc6 = _mm512_loadu_ps(optr + dd + 6 * 16); + __m512 acc7 = _mm512_loadu_ps(optr + dd + 7 * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + acc0 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd 
+ 0 * 16), acc0); + acc1 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 1 * 16), acc1); + acc2 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 2 * 16), acc2); + acc3 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 3 * 16), acc3); + acc4 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 4 * 16), acc4); + acc5 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 5 * 16), acc5); + acc6 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 6 * 16), acc6); + acc7 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 7 * 16), acc7); + } + + _mm512_storeu_ps(optr + dd + 0 * 16, acc0); + _mm512_storeu_ps(optr + dd + 1 * 16, acc1); + _mm512_storeu_ps(optr + dd + 2 * 16, acc2); + _mm512_storeu_ps(optr + dd + 3 * 16, acc3); + _mm512_storeu_ps(optr + dd + 4 * 16, acc4); + _mm512_storeu_ps(optr + dd + 5 * 16, acc5); + _mm512_storeu_ps(optr + dd + 6 * 16, acc6); + _mm512_storeu_ps(optr + dd + 7 * 16, acc7); + } + + for (; dd + 15 < D; dd += 16) + { + __m512 acc = _mm512_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * D + dd), acc); + _mm512_storeu_ps(optr + dd, acc); + } + + for (; dd < D; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * D + dd]; + optr[dd] = acc; + } + } +} + + +static inline void softmax_tile_avx512(float* P, const float* S, + float* m_vec, float* l_vec, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); + float m_new = _mm512_comp_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + l_vec[i] *= scale_factor; + + __m512 vm_new = _mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < n; j += 16) + 
{ + __m512 svec = _mm512_loadu_ps(sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_storeu_ps(pptr + j, evec); + vsum = _mm512_add_ps(vsum, evec); + } + float l_add = _mm512_comp_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; + } +} + +static inline void vec_scale_avx512(float* x, float s, int n) +{ + __m512 vscale = _mm512_set1_ps(s); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); + for (; i < n; i++) + x[i] *= s; +} + +static inline void vec_zero_avx512(float* x, int n) +{ + __m512 zero = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, zero); + for (; i < n; i++) + x[i] = 0.f; +} + +#endif // __AVX512F__ + +#if __AVX__ + +static void qk_gemm_avx(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + for (int mi = 0; mi < 6; mi++) + { + const float* qptr = Q + (i + mi) * d; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + } + + S[(i + mi) * n + j + 0] = sum0 * scale; + S[(i + mi) * n + j + 1] = sum1 * scale; + } + } + + for (; j < n; j++) + { + for (int mi = 0; mi < 6; mi++) + { + const float* qptr = Q + (i + mi) * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m256 vacc = 
_mm256_setzero_ps(); + for (; k + 7 < d; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + sum = _mm256_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[(i + mi) * n + j] = sum * scale; + } + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * d; + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + } + + S[i * n + j + 0] = sum0 * scale; + S[i * n + j + 1] = sum1 * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m256 vacc = _mm256_setzero_ps(); + for (; k + 7 < d; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + sum = _mm256_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + + +template +static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + for (int mi = 0; mi < 6; mi++) + { + const float* qptr = Q + (i + mi) * D; + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, 
_mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + } + + for (; j < n; j++) + { + for (int mi = 0; mi < 6; mi++) + { + const float* qptr = Q + (i + mi) * D; + const float* kptr = K + j * D; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < D; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[(i + mi) * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < D; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + + +template +static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + for (int mi = 0; mi < M_BLOCK; 
mi++) + { + __m256 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_loadu_ps(op[mi] + dd + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm256_storeu_ps(op[mi] + dd + vi * 8, acc[vi]); + } + } + + for (; dd + 7 < d; dd += 8) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 acc = _mm256_loadu_ps(op[mi] + dd); + for (int j = 0; j < n; j++) + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), _mm256_loadu_ps(V + j * d + dd), acc); + _mm256_storeu_ps(op[mi] + dd, acc); + } + } + + for (; dd < d; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m256 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_loadu_ps(optr + dd + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm256_storeu_ps(optr + dd + vi * 8, acc[vi]); + } + + for (; dd + 7 < d; dd += 8) + { + __m256 acc = _mm256_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); + _mm256_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; + } + } +} + + +template +static inline void 
pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(optr + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(optr + vi * 8, acc[vi]); + } +} + + +static inline void softmax_tile_avx(float* P, const float* S, + float* m_vec, float* l_vec, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + + __m256 vmax = _mm256_set1_ps(m_vec[i]); + int j = 0; + for (; j + 7 < n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(sptr + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + l_vec[i] *= scale_factor; + + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < n; j += 8) + { + __m256 svec = _mm256_loadu_ps(sptr + j); + __m256 evec = 
exp256_ps(_mm256_sub_ps(svec, vm_new)); + _mm256_storeu_ps(pptr + j, evec); + vsum = _mm256_add_ps(vsum, evec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; + } +} + +static inline void vec_scale_avx(float* x, float s, int n) +{ + __m256 vscale = _mm256_set1_ps(s); + int i = 0; + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale)); + for (; i < n; i++) + x[i] *= s; +} + +static inline void vec_zero_avx(float* x, int n) +{ + __m256 zero = _mm256_setzero_ps(); + int i = 0; + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, zero); + for (; i < n; i++) + x[i] = 0.f; +} + +#endif // __AVX__ + +#if __SSE2__ + +static void qk_gemm_sse2(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j < n; j++) + { + const float* kptr = K + j * d; + + for (int mi = 0; mi < 4; mi++) + { + const float* qptr = Q + (i + mi) * d; + __m128 acc0 = _mm_setzero_ps(); + + int k = 0; + for (; k + 3 < d; k += 4) + { + acc0 = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), acc0); + } + + float sum0 = _mm_reduce_add_ps(acc0); + + for (; k < d; k++) + { + sum0 += qptr[k] * kptr[k]; + } + + S[(i + mi) * n + j] = sum0 * scale; + } + } + } + + for (; i < m; i++) + { + for (int j = 0; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m128 vacc = _mm_setzero_ps(); + for (; k + 3 < d; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + sum = _mm_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + + +template +static inline void qk_gemm_specialized_sse2(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 4 
<= m; i += 4) + { + for (int j = 0; j < n; j++) + { + const float* kptr = K + j * D; + + const float* q0 = Q + (i + 0) * D; + const float* q1 = Q + (i + 1) * D; + const float* q2 = Q + (i + 2) * D; + const float* q3 = Q + (i + 3) * D; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + __m128 acc2 = _mm_setzero_ps(); + __m128 acc3 = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + + __m128 qvec = _mm_loadu_ps(q0 + k); + acc0 = _mm_comp_fmadd_ps(qvec, kvec, acc0); + + qvec = _mm_loadu_ps(q1 + k); + acc1 = _mm_comp_fmadd_ps(qvec, kvec, acc1); + + qvec = _mm_loadu_ps(q2 + k); + acc2 = _mm_comp_fmadd_ps(qvec, kvec, acc2); + + qvec = _mm_loadu_ps(q3 + k); + acc3 = _mm_comp_fmadd_ps(qvec, kvec, acc3); + } + + S[(i + 0) * n + j] = _mm_reduce_add_ps(acc0) * scale; + S[(i + 1) * n + j] = _mm_reduce_add_ps(acc1) * scale; + S[(i + 2) * n + j] = _mm_reduce_add_ps(acc2) * scale; + S[(i + 3) * n + j] = _mm_reduce_add_ps(acc3) * scale; + } + } + + for (; i < m; i++) + { + for (int j = 0; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m128 vacc = _mm_setzero_ps(); + for (int k = 0; k < D; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; + } + } +} + + +template +static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 4; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m128 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_loadu_ps(op[mi] + dd + vi * 4); + + for (int j = 0; j < n; j++) + { + 
__m128 pvec = _mm_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * d + dd + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm_storeu_ps(op[mi] + dd + vi * 4, acc[vi]); + } + } + + for (; dd + 3 < d; dd += 4) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m128 acc = _mm_loadu_ps(op[mi] + dd); + for (int j = 0; j < n; j++) + acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[mi][j]), _mm_loadu_ps(V + j * d + dd), acc); + _mm_storeu_ps(op[mi] + dd, acc); + } + } + + for (; dd < d; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m128 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_loadu_ps(optr + dd + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * d + dd + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm_storeu_ps(optr + dd + vi * 4, acc[vi]); + } + + for (; dd + 3 < d; dd += 4) + { + __m128 acc = _mm_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), _mm_loadu_ps(V + j * d + dd), acc); + _mm_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; + } + } +} + + +template +static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 4; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; 
mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m128 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_loadu_ps(op[mi] + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * D + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm_storeu_ps(op[mi] + vi * 4, acc[vi]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + __m128 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_loadu_ps(optr + vi * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * D + vi * 4), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm_storeu_ps(optr + vi * 4, acc[vi]); + } +} + + +static inline void softmax_tile_sse2(float* P, const float* S, + float* m_vec, float* l_vec, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + + __m128 vmax = _mm_set1_ps(m_vec[i]); + int j = 0; + for (; j + 3 < n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(sptr + j)); + float m_new = _mm_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + l_vec[i] *= scale_factor; + + __m128 vm_new = _mm_set1_ps(m_new); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < n; j += 4) + { + __m128 svec = _mm_loadu_ps(sptr + j); + __m128 evec = exp_ps(_mm_sub_ps(svec, vm_new)); + _mm_storeu_ps(pptr + j, evec); + vsum = _mm_add_ps(vsum, evec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; + } +} + +static inline void 
vec_scale_sse2(float* x, float s, int n) +{ + __m128 vscale = _mm_set1_ps(s); + int i = 0; + for (; i + 3 < n; i += 4) + _mm_storeu_ps(x + i, _mm_mul_ps(_mm_loadu_ps(x + i), vscale)); + for (; i < n; i++) + x[i] *= s; +} + +static inline void vec_zero_sse2(float* x, int n) +{ + __m128 zero = _mm_setzero_ps(); + int i = 0; + for (; i + 3 < n; i += 4) + _mm_storeu_ps(x + i, zero); + for (; i < n; i++) + x[i] = 0.f; +} + +#endif // __SSE2__ + + +static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, + int m, int n, int d, float scale) +{ + if (d == 128) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<128>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 64) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<64>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 512) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<512>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(S, Q, K, m, n, scale); + return; +#endif + } + +#if __AVX512F__ + qk_gemm_avx512(S, Q, K, m, n, d, scale); +#elif __AVX__ + qk_gemm_avx(S, Q, K, m, n, d, scale); +#elif __SSE2__ + qk_gemm_sse2(S, Q, K, m, n, d, scale); +#else + qk_gemm_scalar(S, Q, K, m, n, d, scale); +#endif +} + +static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, + int m, int n, int d) +{ + if (d == 128) + { +#if __AVX512F__ + pv_gemm_avx512<2, 128>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 128>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<4, 128>(O, P, V, m, n); + return; +#endif + } + if (d == 64) + { +#if __AVX512F__ + 
pv_gemm_avx512<2, 64>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 64>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(O, P, V, m, n); + return; +#endif + } + if (d == 512) + { +#if __AVX512F__ + pv_gemm_avx512<2, 512>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 512>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<4, 512>(O, P, V, m, n); + return; +#endif + } + +#if __AVX512F__ + pv_gemm_avx512<2, 128>(O, P, V, m, n, d); +#elif __AVX__ + pv_gemm_avx<2, 32>(O, P, V, m, n, d); +#elif __SSE2__ + pv_gemm_sse2<4, 4>(O, P, V, m, n, d); +#else + pv_gemm_scalar(O, P, V, m, n, d); +#endif +} + +static inline void vec_scale_dispatch(float* x, float s, int n) +{ +#if __AVX512F__ + vec_scale_avx512(x, s, n); +#elif __AVX__ + vec_scale_avx(x, s, n); +#elif __SSE2__ + vec_scale_sse2(x, s, n); +#else + vec_scale_scalar(x, s, n); +#endif +} + +static inline void vec_zero_dispatch(float* x, int n) +{ +#if __AVX512F__ + vec_zero_avx512(x, n); +#elif __AVX__ + vec_zero_avx(x, n); +#elif __SSE2__ + vec_zero_sse2(x, n); +#else + vec_zero_scalar(x, n); +#endif +} + +static inline void softmax_tile_dispatch(float* P, const float* S, + float* m_vec, float* l_vec, int m, int n) +{ +#if __AVX512F__ + softmax_tile_avx512(P, S, m_vec, l_vec, m, n); +#elif __AVX__ + softmax_tile_avx(P, S, m_vec, l_vec, m, n); +#elif __SSE2__ + softmax_tile_sse2(P, S, m_vec, l_vec, m, n); +#else + softmax_tile_scalar(P, S, m_vec, l_vec, m, n); +#endif +} + +// Timing instrumentation removed + int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { Option opt = _opt; @@ -142,7 +1708,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const float _scale = scale == 0.f ? 
1.f / sqrtf((float)embed_dim) : scale; const int BLOCK_M = 64; - const int BLOCK_N = 64; + const int BLOCK_N = 128; Mat& top_blob = top_blobs[0]; top_blob.create(out_embed_dim, src_seqlen, num_heads, 4u, opt.blob_allocator); @@ -239,8 +1805,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - float m_vec[64]; - float l_vec[64]; + float m_vec[BLOCK_M]; + float l_vec[BLOCK_M]; for (int i = 0; i < block_m; i++) { m_vec[i] = -FLT_MAX; @@ -348,57 +1914,55 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } #endif // NCNN_INT8 - Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat s_vec(BLOCK_N, opt.num_threads, 4u, opt.workspace_allocator); + // FP32 optimized path using tiled GEMM + online softmax + Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - if (o_accum.empty() || s_vec.empty()) + if (s_vec.empty() || p_vec.empty()) return -100; #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) + for (int g = 0; g < num_group; g++) { - const Mat query_head = query.channel(q); - const Mat key_head = key.channel(q / num_heads_per_group); - const Mat value_head = value.channel(q / num_heads_per_group); - Mat top_blob_head = top_blob.channel(q); + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); - Mat mask_head; - if (attn_mask) + for (int hq = 0; hq < num_heads_per_group; hq++) { - const Mat& maskm = attn_mask_blob; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + Mat mask_head; + if (attn_mask) { - mask_head = maskm; + const Mat& maskm = attn_mask_blob; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } } - } - Mat o_accum_head = o_accum.channel(get_omp_thread_num()); - float* s_vec_ptr = s_vec.row(get_omp_thread_num()); - - for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) - { - int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; - int block_m = m_end - m_start; + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + float* p_vec_ptr = p_vec.row(get_omp_thread_num()); - for (int i = 0; i < block_m; i++) + // Allocate per-row m_vec/l_vec on stack or heap + float m_vec_all[1024]; + float l_vec_all[1024]; + for (int i = 0; i < src_seqlen; i++) { - float* optr = o_accum_head.row(i); - for (int k = 0; k < out_embed_dim; k++) - { - optr[k] = 0.f; - } + m_vec_all[i] = -FLT_MAX; + l_vec_all[i] = 0.f; } - float m_vec[64]; - float l_vec[64]; - for (int i = 0; i < block_m; i++) + // Zero output accumulator + for (int i = 0; i < src_seqlen; i++) { - m_vec[i] = -FLT_MAX; - l_vec[i] = 0.f; + vec_zero_dispatch(top_blob_head.row(i), out_embed_dim); } for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) @@ -406,70 +1970,99 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; - for (int i = 0; i < block_m; i++) + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) { - const float* qptr = query_head.row(m_start + i); + int m_end = m_start + BLOCK_M < src_seqlen ? 
m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; - for (int j = 0; j < block_n; j++) - { - const float* kptr = key_head.row(n_start + j); - float sum = 0.f; - for (int k = 0; k < embed_dim; k++) - { - sum += qptr[k] * kptr[k]; - } - s_vec_ptr[j] = sum * _scale; - } + // Step 1: Compute Q * K^T -> S tile + qk_gemm_dispatch(s_vec_ptr, + query_head.row(m_start), + key_head.row(n_start), + block_m, block_n, embed_dim, _scale); + // Step 2: Apply attention mask if (attn_mask) { - const float* mptr = mask_head.row(m_start + i) + n_start; - for (int j = 0; j < block_n; j++) + for (int i = 0; i < block_m; i++) { - s_vec_ptr[j] += mptr[j]; + const float* mptr = mask_head.row(m_start + i) + n_start; + float* sptr = s_vec_ptr + i * block_n; + int j = 0; +#if __AVX512F__ + for (; j + 15 < block_n; j += 16) + { + _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); + } +#elif __AVX__ + for (; j + 7 < block_n; j += 8) + { + _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); + } +#elif __SSE2__ + for (; j + 3 < block_n; j += 4) + { + _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); + } +#endif + for (; j < block_n; j++) + { + sptr[j] += mptr[j]; + } } } - float m_new = m_vec[i]; - for (int j = 0; j < block_n; j++) - { - m_new = std::max(m_new, s_vec_ptr[j]); - } - - float scale_factor = expf(m_vec[i] - m_new); - float l_new = l_vec[i] * scale_factor; - - float* optr = o_accum_head.row(i); - for (int k = 0; k < out_embed_dim; k++) + // Step 3: Online softmax, compute P = exp(S - m_new) + float m_old[BLOCK_M]; + for (int i = 0; i < block_m; i++) { - optr[k] *= scale_factor; + m_old[i] = m_vec_all[m_start + i]; } + softmax_tile_dispatch(p_vec_ptr, s_vec_ptr, m_vec_all + m_start, l_vec_all + m_start, block_m, block_n); - for (int j = 0; j < block_n; j++) + // Rescale O accumulator when max increases + for (int i = 0; i < block_m; i++) { - float 
p = expf(s_vec_ptr[j] - m_new); - l_new += p; - - const float* vptr = value_head.row(n_start + j); - for (int k = 0; k < out_embed_dim; k++) + float scale_factor = expf(m_old[i] - m_vec_all[m_start + i]); + if (scale_factor != 1.f) { - optr[k] += p * vptr[k]; + vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factor, out_embed_dim); } } - m_vec[i] = m_new; - l_vec[i] = l_new; + // Step 4: O += P * V_tile + pv_gemm_dispatch(top_blob_head.row(m_start), p_vec_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); } } - for (int i = 0; i < block_m; i++) + // Normalize all rows + for (int i = 0; i < src_seqlen; i++) { - float* optr = o_accum_head.row(i); - float* outptr = top_blob_head.row(m_start + i); - float inv_l = 1.f / l_vec[i]; - for (int k = 0; k < out_embed_dim; k++) + float* outptr = top_blob_head.row(i); + float inv_l = 1.f / l_vec_all[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + { + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); + } +#elif __AVX__ + __m256 vinv_l = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + { + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(outptr + k), vinv_l)); + } +#elif __SSE2__ + __m128 vinv_l = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + { + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(outptr + k), vinv_l)); + } +#endif + for (; k < out_embed_dim; k++) { - outptr[k] = optr[k] * inv_l; + outptr[k] *= inv_l; } } } From 92822797bca60309fa7d77cd434ff8ae3eb7cc71 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 01:22:30 +0800 Subject: [PATCH 03/53] avx512 tail opt --- src/layer/x86/sdpa_x86.cpp | 145 ++++++++++++++++++++++--------------- 1 file changed, 88 insertions(+), 57 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index f2a64b54d9aa..46220ff069f8 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ 
b/src/layer/x86/sdpa_x86.cpp @@ -182,24 +182,20 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); } - float sum0 = _mm512_comp_reduce_add_ps(acc0); - float sum1 = _mm512_comp_reduce_add_ps(acc1); - float sum2 = _mm512_comp_reduce_add_ps(acc2); - float sum3 = _mm512_comp_reduce_add_ps(acc3); - - for (; k < d; k++) + if (k < d) { - float qv = qptr[k]; - sum0 += qv * k0[k]; - sum1 += qv * k1[k]; - sum2 += qv * k2[k]; - sum3 += qv * k3[k]; + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __m512 qvec = _mm512_maskz_loadu_ps(mask, qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k3 + k), acc3); } - S[(i + mi) * n + j + 0] = sum0 * scale; - S[(i + mi) * n + j + 1] = sum1 * scale; - S[(i + mi) * n + j + 2] = sum2 * scale; - S[(i + mi) * n + j + 3] = sum3 * scale; + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; } } @@ -214,10 +210,12 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, __m512 vacc = _mm512_setzero_ps(); for (; k + 15 < d; k += 16) vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - sum = _mm512_comp_reduce_add_ps(vacc); - for (; k < d; k++) - sum += qptr[k] * kptr[k]; - S[(i + mi) * n + j] = sum * scale; + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, qptr + k), _mm512_maskz_loadu_ps(mask, kptr + k), vacc); + } + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } } 
} @@ -248,24 +246,20 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); } - float sum0 = _mm512_comp_reduce_add_ps(acc0); - float sum1 = _mm512_comp_reduce_add_ps(acc1); - float sum2 = _mm512_comp_reduce_add_ps(acc2); - float sum3 = _mm512_comp_reduce_add_ps(acc3); - - for (; k < d; k++) + if (k < d) { - float qv = qptr[k]; - sum0 += qv * k0[k]; - sum1 += qv * k1[k]; - sum2 += qv * k2[k]; - sum3 += qv * k3[k]; + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __m512 qvec = _mm512_maskz_loadu_ps(mask, qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k3 + k), acc3); } - S[i * n + j + 0] = sum0 * scale; - S[i * n + j + 1] = sum1 * scale; - S[i * n + j + 2] = sum2 * scale; - S[i * n + j + 3] = sum3 * scale; + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; } for (; j < n; j++) @@ -277,10 +271,12 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, __m512 vacc = _mm512_setzero_ps(); for (; k + 15 < d; k += 16) vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - sum = _mm512_comp_reduce_add_ps(vacc); - for (; k < d; k++) - sum += qptr[k] * kptr[k]; - S[i * n + j] = sum * scale; + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, qptr + k), _mm512_maskz_loadu_ps(mask, kptr + k), vacc); + } + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } } } @@ -461,14 +457,18 @@ static inline void pv_gemm_avx512(float* O, const float* 
P, const float* V, int _mm512_storeu_ps(op[mi] + dd, acc[mi]); } - for (; dd < d; dd++) + if (dd < d) { + __mmask16 mask = (__mmask16)((1u << (d - dd)) - 1); for (int mi = 0; mi < M_BLOCK; mi++) { - float acc = op[mi][dd]; + __m512 acc = _mm512_maskz_loadu_ps(mask, op[mi] + dd); for (int j = 0; j < n; j++) - acc += pptr[mi][j] * V[j * d + dd]; - op[mi][dd] = acc; + { + __m512 vvec = _mm512_maskz_loadu_ps(mask, V + j * d + dd); + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc); + } + _mm512_mask_storeu_ps(op[mi] + dd, mask, acc); } } } @@ -504,19 +504,23 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int _mm512_storeu_ps(optr + dd, acc); } - for (; dd < d; dd++) + if (dd < d) { - float acc = optr[dd]; + __mmask16 mask = (__mmask16)((1u << (d - dd)) - 1); + __m512 acc = _mm512_maskz_loadu_ps(mask, optr + dd); for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * d + dd]; - optr[dd] = acc; + { + __m512 vvec = _mm512_maskz_loadu_ps(mask, V + j * d + dd); + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), vvec, acc); + } + _mm512_mask_storeu_ps(optr + dd, mask, acc); } } } template -static __attribute__((noinline)) void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) +static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) { const int VEC_PER_D = D / 16; int i = 0; @@ -684,9 +688,13 @@ static inline void softmax_tile_avx512(float* P, const float* S, int j = 0; for (; j + 15 < n; j += 16) vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); + if (j < n) + { + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 tail = _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, sptr + j); + vmax = _mm512_max_ps(vmax, tail); + } float m_new = _mm512_comp_reduce_max_ps(vmax); - for (; j < n; j++) - m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); l_vec[i] *= scale_factor; @@ -701,12 +709,15 @@ static inline void softmax_tile_avx512(float* P, const 
float* S, _mm512_storeu_ps(pptr + j, evec); vsum = _mm512_add_ps(vsum, evec); } - float l_add = _mm512_comp_reduce_add_ps(vsum); - for (; j < n; j++) + if (j < n) { - pptr[j] = expf(sptr[j] - m_new); - l_add += pptr[j]; + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 svec = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_mask_storeu_ps(pptr + j, mask, evec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, evec); } + float l_add = _mm512_comp_reduce_add_ps(vsum); l_vec[i] += l_add; m_vec[i] = m_new; } @@ -718,8 +729,11 @@ static inline void vec_scale_avx512(float* x, float s, int n) int i = 0; for (; i + 15 < n; i += 16) _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); - for (; i < n; i++) - x[i] *= s; + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale)); + } } static inline void vec_zero_avx512(float* x, int n) @@ -728,8 +742,11 @@ static inline void vec_zero_avx512(float* x, int n) int i = 0; for (; i + 15 < n; i += 16) _mm512_storeu_ps(x + i, zero); - for (; i < n; i++) - x[i] = 0.f; + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, zero); + } } #endif // __AVX512F__ @@ -1994,6 +2011,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); + _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); + j = block_n; + } #elif __AVX__ for (; j + 7 < block_n; j += 8) { @@ -2047,6 +2072,12 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { _mm512_storeu_ps(outptr + k, 
_mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); } + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, outptr + k), vinv_l)); + k = out_embed_dim; + } #elif __AVX__ __m256 vinv_l = _mm256_set1_ps(inv_l); for (; k + 7 < out_embed_dim; k += 8) From 5abaa8483d9a079a3f52f604d896a81c82d85443 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 01:41:41 +0800 Subject: [PATCH 04/53] gemv for decode --- src/layer/x86/sdpa_x86.cpp | 504 ++++++++++++++++++++++++++++++++++++ tests/test_sdpa.cpp | 8 +- tests/test_sdpa_kvcache.cpp | 8 +- 3 files changed, 518 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 46220ff069f8..40bd2b3eb30e 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -126,6 +126,68 @@ static inline void vec_zero_scalar(float* x, int n) for (int i = 0; i < n; i++) x[i] = 0.f; } +static void sdpa_decode_scalar(float* out, const float* q, + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; + float s[BLOCK_N]; + + for (int k = 0; k < out_d; k++) out[k] = 0.f; + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < n; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n - n_start); + + for (int j = 0; j < block_n; j++) + { + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += q[k] * K[(n_start + j) * d + k]; + s[j] = sum * scale; + } + + if (mask) + { + for (int j = 0; j < block_n; j++) + s[j] += mask[n_start + j]; + } + + float tile_m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); + + float new_m = std::max(m, tile_m); + float scale_factor = expf(m - new_m); + l *= scale_factor; + for (int k = 0; k < out_d; k++) + out[k] *= scale_factor; + + float l_add = 0.f; + for (int j = 0; j < block_n; j++) 
+ { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; + + for (int j = 0; j < block_n; j++) + { + for (int k = 0; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; + } + + m = new_m; + } + + float inv_l = 1.f / l; + for (int k = 0; k < out_d; k++) + out[k] *= inv_l; +} + static inline void softmax_tile_scalar(float* P, const float* S, float* m_vec, float* l_vec, int m, int n) { @@ -749,6 +811,157 @@ static inline void vec_zero_avx512(float* x, int n) } } +static inline void decode_qk_dot_avx512(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + for (; j + 3 < block_n; j += 4) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + const float* k2 = K + (n_start + j + 2) * d; + const float* k3 = K + (n_start + j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + 
for (; j < block_n; j++) + { + __m512 acc = _mm512_setzero_ps(); + int k = 0; + for (; k + 15 < d; k += 16) + acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), _mm512_loadu_ps(K + (n_start + j) * d + k), acc); + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), _mm512_maskz_loadu_ps(mask_d, K + (n_start + j) * d + k), acc); + } + s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + } +} + +static inline void decode_pv_gemv_avx512(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + __m512 pvec = _mm512_set1_ps(s[j]); + int k = 0; + for (; k + 15 < out_d; k += 16) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vval, oval)); + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); + _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec, vval, oval)); + } + } +} + +static void sdpa_decode_avx512(float* out, const float* q, + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; + __attribute__((aligned(64))) float s[BLOCK_N]; + + vec_zero_avx512(out, out_d); + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < n; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n - n_start); + + decode_qk_dot_avx512(s, q, K, n_start, block_n, d, scale); + + if (mask) + { + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + 
_mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); + } + } + + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + } + float tile_m = _mm512_comp_reduce_max_ps(vmax); + + float new_m = std::max(m, tile_m); + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_avx512(out, scale_factor, out_d); + + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + l += _mm512_comp_reduce_add_ps(vsum); + + decode_pv_gemv_avx512(out, s, V, n_start, block_n, out_d); + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale_avx512(out, inv_l, out_d); +} + #endif // __AVX512F__ #if __AVX__ @@ -1162,6 +1375,135 @@ static inline void vec_zero_avx(float* x, int n) x[i] = 0.f; } +static inline void decode_qk_dot_avx(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qv = _mm256_loadu_ps(q + k); + acc0 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k0 + k), 
acc0); + acc1 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + sum0 += q[k] * k0[k]; + sum1 += q[k] * k1[k]; + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + + for (; j < block_n; j++) + { + const float* kptr = K + (n_start + j) * d; + __m256 acc = _mm256_setzero_ps(); + int k = 0; + for (; k + 7 < d; k += 8) + acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), _mm256_loadu_ps(kptr + k), acc); + float sum = _mm256_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * kptr[k]; + s[j] = sum * scale; + } +} + +static inline void decode_pv_gemv_avx(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + __m256 pvec = _mm256_set1_ps(s[j]); + int k = 0; + for (; k + 7 < out_d; k += 8) + { + __m256 oval = _mm256_loadu_ps(out + k); + __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); + _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; + } +} + +static void sdpa_decode_avx(float* out, const float* q, + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; + __attribute__((aligned(32))) float s[BLOCK_N]; + + vec_zero_avx(out, out_d); + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < n; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n - n_start); + + decode_qk_dot_avx(s, q, K, n_start, block_n, d, scale); + + if (mask) + { + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; + } + + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, 
_mm256_loadu_ps(s + j)); + float tile_m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); + + float new_m = std::max(m, tile_m); + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_avx(out, scale_factor, out_d); + + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; + + decode_pv_gemv_avx(out, s, V, n_start, block_n, out_d); + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale_avx(out, inv_l, out_d); +} + #endif // __AVX__ #if __SSE2__ @@ -1496,6 +1838,104 @@ static inline void vec_zero_sse2(float* x, int n) x[i] = 0.f; } +static inline void decode_qk_dot_sse2(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) +{ + for (int j = 0; j < block_n; j++) + { + __m128 acc = _mm_setzero_ps(); + int k = 0; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(q + k), _mm_loadu_ps(K + (n_start + j) * d + k), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * K[(n_start + j) * d + k]; + s[j] = sum * scale; + } +} + +static inline void decode_pv_gemv_sse2(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + __m128 pvec = _mm_set1_ps(s[j]); + int k = 0; + for (; k + 3 < out_d; k += 4) + { + __m128 oval = _mm_loadu_ps(out + k); + __m128 vval = _mm_loadu_ps(V + (n_start + j) * out_d + k); + _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; + } +} + +static void sdpa_decode_sse2(float* out, const float* q, + const float* K, 
const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; + __attribute__((aligned(16))) float s[BLOCK_N]; + + vec_zero_sse2(out, out_d); + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < n; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n - n_start); + + decode_qk_dot_sse2(s, q, K, n_start, block_n, d, scale); + + if (mask) + { + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; + } + + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float tile_m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); + + float new_m = std::max(m, tile_m); + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_sse2(out, scale_factor, out_d); + + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; + + decode_pv_gemv_sse2(out, s, V, n_start, block_n, out_d); + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale_sse2(out, inv_l, out_d); +} + #endif // __SSE2__ @@ -1647,6 +2087,21 @@ static inline void softmax_tile_dispatch(float* P, const float* S, #endif } +static inline void sdpa_decode_dispatch(float* out, const float* q, + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ +#if __AVX512F__ + sdpa_decode_avx512(out, q, K, V, mask, n, d, out_d, scale); +#elif __AVX__ + sdpa_decode_avx(out, q, K, V, mask, n, d, out_d, scale); +#elif __SSE2__ + 
sdpa_decode_sse2(out, q, K, V, mask, n, d, out_d, scale); +#else + sdpa_decode_scalar(out, q, K, V, mask, n, d, out_d, scale); +#endif +} + // Timing instrumentation removed int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const @@ -1932,6 +2387,55 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to #endif // NCNN_INT8 // FP32 optimized path using tiled GEMM + online softmax + if (src_seqlen == 1) + { + // Decode path: fused GEMV kernel for single-query attention + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } + + return 0; + } + Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); diff --git a/tests/test_sdpa.cpp b/tests/test_sdpa.cpp index 04cb92ae2a24..f25633b435ab 100644 --- a/tests/test_sdpa.cpp +++ b/tests/test_sdpa.cpp @@ -47,7 +47,13 @@ static int test_sdpa_0() || test_sdpa(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 1.f) || test_sdpa(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f) || test_sdpa(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f) - || test_sdpa(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f); + || test_sdpa(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f) + || test_sdpa(RandomMat(32, 1, 8), RandomMat(32, 66, 8), RandomMat(20, 66, 8), 0) + || test_sdpa(RandomMat(26, 1, 8), RandomMat(26, 61, 8), RandomMat(18, 61, 8), 1) + || test_sdpa(RandomMat(64, 1, 12), RandomMat(64, 128, 2), RandomMat(64, 128, 2), 0) + || test_sdpa(RandomMat(64, 1, 12), RandomMat(64, 127, 2), RandomMat(48, 127, 2), 1) + || test_sdpa(RandomMat(44, 1, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 1.f) + || test_sdpa(RandomMat(12, 1, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f); } #if NCNN_INT8 diff --git a/tests/test_sdpa_kvcache.cpp b/tests/test_sdpa_kvcache.cpp index 3e0a20221ad0..0ad6da25eb7d 100644 --- a/tests/test_sdpa_kvcache.cpp +++ b/tests/test_sdpa_kvcache.cpp @@ -53,7 +53,13 @@ static int test_sdpa_0() || test_sdpa_kvcache(RandomMat(44, 128, 4), RandomMat(44, 123, 4), RandomMat(55, 123, 4), 0, 0) || 
test_sdpa_kvcache(RandomMat(12, 127, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 0) || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 3) - || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5); + || test_sdpa_kvcache(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, 5) + || test_sdpa_kvcache(RandomMat(32, 1, 8), RandomMat(32, 1, 8), RandomMat(20, 1, 8), 0, 11) + || test_sdpa_kvcache(RandomMat(26, 1, 8), RandomMat(26, 1, 8), RandomMat(18, 1, 8), 1, 11) + || test_sdpa_kvcache(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(64, 1, 2), 0, 1) + || test_sdpa_kvcache(RandomMat(64, 1, 12), RandomMat(64, 1, 2), RandomMat(48, 1, 2), 1, 1) + || test_sdpa_kvcache(RandomMat(44, 1, 4), RandomMat(44, 1, 4), RandomMat(55, 1, 4), 0, 0) + || test_sdpa_kvcache(RandomMat(12, 1, 4), RandomMat(12, 1, 4), RandomMat(55, 1, 4), 1, 0); } #if NCNN_INT8 From 42d331da131af58868304b19a92ab05e45f10e1c Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 01:55:37 +0800 Subject: [PATCH 05/53] permute loop --- src/layer/x86/sdpa_x86.cpp | 106 ++++++++++++++++++------------------- 1 file changed, 53 insertions(+), 53 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 40bd2b3eb30e..63668182a93d 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -2471,30 +2471,30 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* s_vec_ptr = s_vec.row(get_omp_thread_num()); float* p_vec_ptr = p_vec.row(get_omp_thread_num()); - // Allocate per-row m_vec/l_vec on stack or heap - float m_vec_all[1024]; - float l_vec_all[1024]; - for (int i = 0; i < src_seqlen; i++) + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) { - m_vec_all[i] = -FLT_MAX; - l_vec_all[i] = 0.f; - } + int m_end = m_start + BLOCK_M < src_seqlen ? 
m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; - // Zero output accumulator - for (int i = 0; i < src_seqlen; i++) - { - vec_zero_dispatch(top_blob_head.row(i), out_embed_dim); - } + // Per-M-tile softmax statistics + float m_vec[BLOCK_M]; + float l_vec[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) - { - int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; - int block_n = n_end - n_start; + // Zero output accumulator for this M tile + for (int i = 0; i < block_m; i++) + { + vec_zero_dispatch(top_blob_head.row(m_start + i), out_embed_dim); + } - for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) { - int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; - int block_m = m_end - m_start; + int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; // Step 1: Compute Q * K^T -> S tile qk_gemm_dispatch(s_vec_ptr, @@ -2545,14 +2545,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float m_old[BLOCK_M]; for (int i = 0; i < block_m; i++) { - m_old[i] = m_vec_all[m_start + i]; + m_old[i] = m_vec[i]; } - softmax_tile_dispatch(p_vec_ptr, s_vec_ptr, m_vec_all + m_start, l_vec_all + m_start, block_m, block_n); + softmax_tile_dispatch(p_vec_ptr, s_vec_ptr, m_vec, l_vec, block_m, block_n); // Rescale O accumulator when max increases for (int i = 0; i < block_m; i++) { - float scale_factor = expf(m_old[i] - m_vec_all[m_start + i]); + float scale_factor = expf(m_old[i] - m_vec[i]); if (scale_factor != 1.f) { vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factor, out_embed_dim); @@ -2562,42 +2562,42 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to // Step 4: O += P * V_tile pv_gemm_dispatch(top_blob_head.row(m_start), 
p_vec_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); } - } - // Normalize all rows - for (int i = 0; i < src_seqlen; i++) - { - float* outptr = top_blob_head.row(i); - float inv_l = 1.f / l_vec_all[i]; - int k = 0; -#if __AVX512F__ - __m512 vinv_l = _mm512_set1_ps(inv_l); - for (; k + 15 < out_embed_dim; k += 16) - { - _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); - } - if (k < out_embed_dim) + // Normalize this M tile + for (int i = 0; i < block_m; i++) { - __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); - _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, outptr + k), vinv_l)); - k = out_embed_dim; - } + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + { + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); + } + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, outptr + k), vinv_l)); + k = out_embed_dim; + } #elif __AVX__ - __m256 vinv_l = _mm256_set1_ps(inv_l); - for (; k + 7 < out_embed_dim; k += 8) - { - _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(outptr + k), vinv_l)); - } + __m256 vinv_l = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + { + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(outptr + k), vinv_l)); + } #elif __SSE2__ - __m128 vinv_l = _mm_set1_ps(inv_l); - for (; k + 3 < out_embed_dim; k += 4) - { - _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(outptr + k), vinv_l)); - } + __m128 vinv_l = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + { + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(outptr + k), vinv_l)); + } #endif - for (; k < out_embed_dim; k++) - { - outptr[k] *= inv_l; + for (; k < 
out_embed_dim; k++) + { + outptr[k] *= inv_l; + } } } } From 85c17e697d74290da2456d44b55e1e5a811fc792 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 01:57:35 +0800 Subject: [PATCH 06/53] opt softmax --- src/layer/x86/sdpa_x86.cpp | 42 ++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 15 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 63668182a93d..5cfe738a0476 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -161,10 +161,13 @@ static void sdpa_decode_scalar(float* out, const float* q, tile_m = std::max(tile_m, s[j]); float new_m = std::max(m, tile_m); - float scale_factor = expf(m - new_m); - l *= scale_factor; - for (int k = 0; k < out_d; k++) - out[k] *= scale_factor; + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + for (int k = 0; k < out_d; k++) + out[k] *= scale_factor; + } float l_add = 0.f; for (int j = 0; j < block_n; j++) @@ -931,9 +934,12 @@ static void sdpa_decode_avx512(float* out, const float* q, float tile_m = _mm512_comp_reduce_max_ps(vmax); float new_m = std::max(m, tile_m); - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_avx512(out, scale_factor, out_d); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_avx512(out, scale_factor, out_d); + } __m512 vm_new = _mm512_set1_ps(new_m); __m512 vsum = _mm512_setzero_ps(); @@ -1474,9 +1480,12 @@ static void sdpa_decode_avx(float* out, const float* q, tile_m = std::max(tile_m, s[j]); float new_m = std::max(m, tile_m); - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_avx(out, scale_factor, out_d); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_avx(out, scale_factor, out_d); + } __m256 vm_new = _mm256_set1_ps(new_m); __m256 vsum = _mm256_setzero_ps(); @@ -1906,9 +1915,12 @@ static void sdpa_decode_sse2(float* out, 
const float* q, tile_m = std::max(tile_m, s[j]); float new_m = std::max(m, tile_m); - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_sse2(out, scale_factor, out_d); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_sse2(out, scale_factor, out_d); + } __m128 vm_new = _mm_set1_ps(new_m); __m128 vsum = _mm_setzero_ps(); @@ -2552,9 +2564,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to // Rescale O accumulator when max increases for (int i = 0; i < block_m; i++) { - float scale_factor = expf(m_old[i] - m_vec[i]); - if (scale_factor != 1.f) + if (m_old[i] != m_vec[i]) { + float scale_factor = expf(m_old[i] - m_vec[i]); vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factor, out_embed_dim); } } From 2269ed69be79419de69436bac3082a02f4655b1a Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 02:13:09 +0800 Subject: [PATCH 07/53] opt --- src/layer/x86/sdpa_x86.cpp | 261 ++++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 117 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 5cfe738a0476..217250a15161 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -2454,162 +2454,189 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (s_vec.empty() || p_vec.empty()) return -100; + int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; + + // Per-head per-M-tile softmax state for cross-N-tile accumulation + Mat m_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); + Mat l_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); + + if (m_state.empty() || l_state.empty()) + return -100; + #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) + for (int idx = 0; idx < num_group * num_m_tiles; idx++) { + int g = idx / num_m_tiles; + int 
m_tile = idx % num_m_tiles; + int m_start = m_tile * BLOCK_M; + int block_m = m_start + BLOCK_M < src_seqlen ? BLOCK_M : src_seqlen - m_start; + const Mat key_head = key.channel(g); const Mat value_head = value.channel(g); + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + float* p_vec_ptr = p_vec.row(get_omp_thread_num()); + + Mat m_state_tile = m_state.channel(idx); + Mat l_state_tile = l_state.channel(idx); + + // Initialize softmax state and zero output accumulator for all Q heads in this group for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; - const Mat query_head = query.channel(q); Mat top_blob_head = top_blob.channel(q); - Mat mask_head; - if (attn_mask) + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + for (int i = 0; i < block_m; i++) { - const Mat& maskm = attn_mask_blob; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; } - float* s_vec_ptr = s_vec.row(get_omp_thread_num()); - float* p_vec_ptr = p_vec.row(get_omp_thread_num()); - - for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + for (int i = 0; i < block_m; i++) { - int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; - int block_m = m_end - m_start; + vec_zero_dispatch(top_blob_head.row(m_start + i), out_embed_dim); + } + } - // Per-M-tile softmax statistics - float m_vec[BLOCK_M]; - float l_vec[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_vec[i] = -FLT_MAX; - l_vec[i] = 0.f; - } + // N-outer loop: each K/V N-tile is loaded once and reused by all Q heads in this group + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; - // Zero output accumulator for this M tile - for (int i = 0; i < block_m; i++) + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + Mat mask_head; + if (attn_mask) { - vec_zero_dispatch(top_blob_head.row(m_start + i), out_embed_dim); + const Mat& maskm = attn_mask_blob; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } } - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) - { - int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; - int block_n = n_end - n_start; + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); - // Step 1: Compute Q * K^T -> S tile - qk_gemm_dispatch(s_vec_ptr, - query_head.row(m_start), - key_head.row(n_start), - block_m, block_n, embed_dim, _scale); + // Step 1: Compute Q * K^T -> S tile + qk_gemm_dispatch(s_vec_ptr, + query_head.row(m_start), + key_head.row(n_start), + block_m, block_n, embed_dim, _scale); - // Step 2: Apply attention mask - if (attn_mask) + // Step 2: Apply attention mask + if (attn_mask) + { + for (int i = 0; i < block_m; i++) { - for (int i = 0; i < block_m; i++) - { - const float* mptr = mask_head.row(m_start + i) + n_start; - float* sptr = s_vec_ptr + i * block_n; - int j = 0; + const float* mptr = mask_head.row(m_start + i) + n_start; + float* sptr = s_vec_ptr + i * block_n; + int j = 0; #if __AVX512F__ - for (; j + 15 < block_n; j += 16) - { - _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); - } - if (j < block_n) - { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); - __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); - _mm512_mask_storeu_ps(sptr + j, mask, 
_mm512_add_ps(_s, _m)); - j = block_n; - } + for (; j + 15 < block_n; j += 16) + { + _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); + } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); + _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); + j = block_n; + } #elif __AVX__ - for (; j + 7 < block_n; j += 8) - { - _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); - } + for (; j + 7 < block_n; j += 8) + { + _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); + } #elif __SSE2__ - for (; j + 3 < block_n; j += 4) - { - _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); - } -#endif - for (; j < block_n; j++) - { - sptr[j] += mptr[j]; - } + for (; j + 3 < block_n; j += 4) + { + _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); } - } - - // Step 3: Online softmax, compute P = exp(S - m_new) - float m_old[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile_dispatch(p_vec_ptr, s_vec_ptr, m_vec, l_vec, block_m, block_n); - - // Rescale O accumulator when max increases - for (int i = 0; i < block_m; i++) - { - if (m_old[i] != m_vec[i]) +#endif + for (; j < block_n; j++) { - float scale_factor = expf(m_old[i] - m_vec[i]); - vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factor, out_embed_dim); + sptr[j] += mptr[j]; } } + } - // Step 4: O += P * V_tile - pv_gemm_dispatch(top_blob_head.row(m_start), p_vec_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); + // Step 3: Online softmax, compute P = exp(S - m_new) + float m_old[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_old[i] = m_vec[i]; } + softmax_tile_dispatch(p_vec_ptr, s_vec_ptr, m_vec, l_vec, block_m, 
block_n); - // Normalize this M tile + // Rescale O accumulator when max increases for (int i = 0; i < block_m; i++) { - float* outptr = top_blob_head.row(m_start + i); - float inv_l = 1.f / l_vec[i]; - int k = 0; -#if __AVX512F__ - __m512 vinv_l = _mm512_set1_ps(inv_l); - for (; k + 15 < out_embed_dim; k += 16) + if (m_old[i] != m_vec[i]) { - _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); - } - if (k < out_embed_dim) - { - __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); - _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, outptr + k), vinv_l)); - k = out_embed_dim; + float scale_factor = expf(m_old[i] - m_vec[i]); + vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factor, out_embed_dim); } + } + + // Step 4: O += P * V_tile + pv_gemm_dispatch(top_blob_head.row(m_start), p_vec_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); + } + } + + // Normalize all Q heads for this M tile + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + Mat top_blob_head = top_blob.channel(q); + float* l_vec = l_state_tile.row(hq); + + for (int i = 0; i < block_m; i++) + { + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + { + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); + } + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, outptr + k), vinv_l)); + k = out_embed_dim; + } #elif __AVX__ - __m256 vinv_l = _mm256_set1_ps(inv_l); - for (; k + 7 < out_embed_dim; k += 8) - { - _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(outptr + k), vinv_l)); - } + __m256 vinv_l = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + { + 
_mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(outptr + k), vinv_l)); + } #elif __SSE2__ - __m128 vinv_l = _mm_set1_ps(inv_l); - for (; k + 3 < out_embed_dim; k += 4) - { - _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(outptr + k), vinv_l)); - } + __m128 vinv_l = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + { + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(outptr + k), vinv_l)); + } #endif - for (; k < out_embed_dim; k++) - { - outptr[k] *= inv_l; - } + for (; k < out_embed_dim; k++) + { + outptr[k] *= inv_l; } } } From 7c29bdad37556a15234be072f48cee662568be62 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 02:14:25 +0800 Subject: [PATCH 08/53] disable bf16s --- src/layer/x86/sdpa_x86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 217250a15161..e0ae7b05f3ae 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -30,7 +30,7 @@ namespace ncnn { SDPA_x86::SDPA_x86() { #if NCNN_BF16 - support_bf16_storage = true; + support_bf16_storage = false; #endif } From d12e3460368acd32a875185bf33b07cd75ddd939 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 02:33:29 +0800 Subject: [PATCH 09/53] fix multithread --- src/layer/x86/sdpa_x86.cpp | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index e0ae7b05f3ae..5f2844f2ba7d 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -2474,8 +2474,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat key_head = key.channel(g); const Mat value_head = value.channel(g); - float* s_vec_ptr = s_vec.row(get_omp_thread_num()); - float* p_vec_ptr = p_vec.row(get_omp_thread_num()); + float* s_vec_ptr = s_vec.channel(get_omp_thread_num()).row(0); + float* p_vec_ptr = 
p_vec.channel(get_omp_thread_num()).row(0); Mat m_state_tile = m_state.channel(idx); Mat l_state_tile = l_state.channel(idx); @@ -2500,18 +2500,23 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - // N-outer loop: each K/V N-tile is loaded once and reused by all Q heads in this group + // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) { int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; + // Process one Q head at a time with a single-head scratch buffer, + // keeping K/V tile resident in cache across hq iterations. for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; const Mat query_head = query.channel(q); Mat top_blob_head = top_blob.channel(q); + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + Mat mask_head; if (attn_mask) { @@ -2526,16 +2531,12 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - - // Step 1: Compute Q * K^T -> S tile qk_gemm_dispatch(s_vec_ptr, query_head.row(m_start), key_head.row(n_start), block_m, block_n, embed_dim, _scale); - // Step 2: Apply attention mask + // Apply attention mask if (attn_mask) { for (int i = 0; i < block_m; i++) @@ -2574,7 +2575,6 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - // Step 3: Online softmax, compute P = exp(S - m_new) float m_old[BLOCK_M]; for (int i = 0; i < block_m; i++) { @@ -2592,7 +2592,6 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - // Step 4: O += P * V_tile pv_gemm_dispatch(top_blob_head.row(m_start), p_vec_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); } } From b39eb23874e8db8417baba357ed2030c5e215fba Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: 
Sun, 26 Apr 2026 02:42:00 +0800 Subject: [PATCH 10/53] fix cache --- src/layer/x86/sdpa_x86.cpp | 250 +++++++++++++++++++++++++++++-------- 1 file changed, 201 insertions(+), 49 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 5f2844f2ba7d..dfa8f4abad96 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -219,69 +219,173 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, int m, int n, int d, float scale) { int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __m512 kv0 = _mm512_maskz_loadu_ps(mask, k0 + k); + __m512 kv1 = _mm512_maskz_loadu_ps(mask, k1 + k); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * d; + + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 8; 
mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __m512 kvec = _mm512_maskz_loadu_ps(mask, kptr + k); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + for (; i + 4 <= m; i += 4) { int j = 0; - for (; j + 4 <= n; j += 4) + for (; j + 2 <= n; j += 2) { const float* k0 = K + (j + 0) * d; const float* k1 = K + (j + 1) * d; - const float* k2 = K + (j + 2) * d; - const float* k3 = K + (j + 3) * d; + __m512 acc[4][2]; for (int mi = 0; mi < 4; mi++) { - const float* qptr = Q + (i + mi) * d; - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } - int k = 0; - for (; k + 15 < d; k += 16) + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } + } - if (k < d) + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __m512 kv0 = _mm512_maskz_loadu_ps(mask, k0 + k); + __m512 kv1 = _mm512_maskz_loadu_ps(mask, k1 + k); + + for (int mi = 0; mi < 4; mi++) { - __mmask16 mask 
= (__mmask16)((1u << (d - k)) - 1); - __m512 qvec = _mm512_maskz_loadu_ps(mask, qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k3 + k), acc3); + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } + } - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; - S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; } } for (; j < n; j++) { + const float* kptr = K + j * d; + + __m512 acc[4]; for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) { - const float* qptr = Q + (i + mi) * d; - const float* kptr = K + j * d; - float sum = 0.f; - int k = 0; - __m512 vacc = _mm512_setzero_ps(); - for (; k + 15 < d; k += 16) - vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - if (k < d) + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, qptr + k), _mm512_maskz_loadu_ps(mask, kptr + k), vacc); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } - S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __m512 kvec = 
_mm512_maskz_loadu_ps(mask, kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; } } @@ -352,39 +456,89 @@ static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const fl int m, int n, float scale) { int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + for (; i + 4 <= m; i += 4) { int j = 0; - for (; j + 4 <= n; j += 4) + for (; j + 2 <= n; j += 2) { const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; - const float* k2 = K + (j + 2) * D; - const float* k3 = K + (j + 3) * 
D; - __m512 acc[4][4]; + __m512 acc[4][2]; for (int mi = 0; mi < 4; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); - acc[mi][2] = _mm512_setzero_ps(); - acc[mi][3] = _mm512_setzero_ps(); } for (int k = 0; k < D; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); - __m512 kv2 = _mm512_loadu_ps(k2 + k); - __m512 kv3 = _mm512_loadu_ps(k3 + k); for (int mi = 0; mi < 4; mi++) { __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); } } @@ -392,8 +546,6 @@ static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const fl { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; } } From faec20bd608ed5682907a3fdf7712ebe6f3cce2e Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 10:07:43 +0800 Subject: [PATCH 11/53] opt cache --- src/layer/x86/sdpa_x86.cpp | 514 +++++++++++++++++++------------------ 1 file changed, 265 insertions(+), 249 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index dfa8f4abad96..c950f52d7025 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -192,7 +192,7 @@ static void sdpa_decode_scalar(float* out, const float* q, } static inline void softmax_tile_scalar(float* P, const float* S, - float* m_vec, float* l_vec, int m, int n) + float* m_vec, float* l_vec, float* scale_out, int m, int n) { for (int i = 0; i < m; i++) { @@ -201,6 +201,7 @@ static inline void softmax_tile_scalar(float* P, const float* S, float m_new = m_vec[i]; for (int j = 
0; j < n; j++) m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; l_vec[i] *= scale_factor; float l_add = 0.f; for (int j = 0; j < n; j++) @@ -619,118 +620,128 @@ template static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 16; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - op[mi] = O + (i + mi) * d; - pptr[mi] = P + (i + mi) * n; - } + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { __m512 acc[M_BLOCK][VEC_PER_UNROLL]; for (int mi = 0; mi < M_BLOCK; mi++) for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[mi][vi] = _mm512_loadu_ps(op[mi] + dd + vi * 16); + acc[mi][vi] = _mm512_loadu_ps(op[mi] + vi * 16); for (int j = 0; j < n; j++) { - __m512 pvec[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - pvec[mi] = _mm512_set1_ps(pptr[mi][j]); - + __m512 vvec[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) { - __m512 vvec = _mm512_loadu_ps(V + j * d + dd + vi * 16); - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi][vi] = _mm512_fmadd_ps(pvec[mi], vvec, acc[mi][vi]); + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm512_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); } } for (int mi = 0; mi < M_BLOCK; mi++) for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm512_storeu_ps(op[mi] + dd + vi * 16, acc[mi][vi]); + _mm512_storeu_ps(op[mi] + vi * 16, acc[mi][vi]); } - for (; dd + 15 < d; dd 
+= 16) + for (; i < m; i++) { - __m512 acc[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_loadu_ps(op[mi] + dd); + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + + __m512 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm512_loadu_ps(optr + vi * 16); for (int j = 0; j < n; j++) { - __m512 vvec = _mm512_loadu_ps(V + j * d + dd); - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + __m512 pvec = _mm512_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); } - for (int mi = 0; mi < M_BLOCK; mi++) - _mm512_storeu_ps(op[mi] + dd, acc[mi]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(optr + vi * 16, acc[vi]); } + } - if (dd < d) + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - __mmask16 mask = (__mmask16)((1u << (d - dd)) - 1); + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; for (int mi = 0; mi < M_BLOCK; mi++) { - __m512 acc = _mm512_maskz_loadu_ps(mask, op[mi] + dd); - for (int j = 0; j < n; j++) - { - __m512 vvec = _mm512_maskz_loadu_ps(mask, V + j * d + dd); - acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc); - } - _mm512_mask_storeu_ps(op[mi] + dd, mask, acc); + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; } - } - } - - for (; i < m; i++) - { - float* optr = O + i * d; - const float* pptr = P + i * n; - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { - __m512 acc[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm512_loadu_ps(optr + dd + vi * 16); + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); for (int j = 0; j < n; j++) { - __m512 pvec = _mm512_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm512_fmadd_ps(pvec, 
_mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); + __m512 vvec = _mm512_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); } - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm512_storeu_ps(optr + dd + vi * 16, acc[vi]); + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); } - for (; dd + 15 < d; dd += 16) + for (; i < m; i++) { - __m512 acc = _mm512_loadu_ps(optr + dd); + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc = _mm512_loadu_ps(optr); for (int j = 0; j < n; j++) acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * d + dd), acc); - _mm512_storeu_ps(optr + dd, acc); + _mm512_storeu_ps(optr, acc); } + } - if (dd < d) + for (; dd < d; dd++) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - __mmask16 mask = (__mmask16)((1u << (d - dd)) - 1); - __m512 acc = _mm512_maskz_loadu_ps(mask, optr + dd); + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } for (int j = 0; j < n; j++) { - __m512 vvec = _mm512_maskz_loadu_ps(mask, V + j * d + dd); - acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), vvec, acc); + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] += pptr[mi][j] * V[j * d + dd]; } - _mm512_mask_storeu_ps(optr + dd, mask, acc); + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + optr[0] += pptr[j] * V[j * d + dd]; } } } @@ -894,7 +905,7 @@ static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int static inline void softmax_tile_avx512(float* P, const float* S, - float* m_vec, float* l_vec, int m, int n) + float* m_vec, float* l_vec, float* scale_out, int m, int n) { for (int i = 0; i < m; i++) { @@ -914,6 +925,7 @@ static inline void softmax_tile_avx512(float* P, const float* S, float m_new 
= _mm512_comp_reduce_max_ps(vmax); float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; l_vec[i] *= scale_factor; __m512 vm_new = _mm512_set1_ps(m_new); @@ -1321,25 +1333,25 @@ template static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 8; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - op[mi] = O + (i + mi) * d; - pptr[mi] = P + (i + mi) * n; - } + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { for (int mi = 0; mi < M_BLOCK; mi++) { __m256 acc[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_loadu_ps(op[mi] + dd + vi * 8); + acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); for (int j = 0; j < n; j++) { @@ -1349,44 +1361,18 @@ static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, } for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm256_storeu_ps(op[mi] + dd + vi * 8, acc[vi]); - } - } - - for (; dd + 7 < d; dd += 8) - { - for (int mi = 0; mi < M_BLOCK; mi++) - { - __m256 acc = _mm256_loadu_ps(op[mi] + dd); - for (int j = 0; j < n; j++) - acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), _mm256_loadu_ps(V + j * d + dd), acc); - _mm256_storeu_ps(op[mi] + dd, acc); + _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); } } - for (; dd < d; dd++) + for (; i < m; i++) { - for (int mi = 0; mi < M_BLOCK; mi++) - { - float acc = op[mi][dd]; - for (int j = 0; j < n; j++) - acc += pptr[mi][j] * V[j * d + dd]; - op[mi][dd] = acc; - } - } - } - - for (; i < m; i++) - { - float* optr = O + i * d; - const float* pptr = P + i * n; + 
float* optr = O + i * d + dd; + const float* pptr = P + i * n; - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { __m256 acc[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_loadu_ps(optr + dd + vi * 8); + acc[vi] = _mm256_loadu_ps(optr + vi * 8); for (int j = 0; j < n; j++) { @@ -1396,23 +1382,67 @@ static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, } for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm256_storeu_ps(optr + dd + vi * 8, acc[vi]); + _mm256_storeu_ps(optr + vi * 8, acc[vi]); + } + } + + for (; dd + 7 < d; dd += 8) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 acc = _mm256_loadu_ps(op[mi]); + for (int j = 0; j < n; j++) + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), _mm256_loadu_ps(V + j * d + dd), acc); + _mm256_storeu_ps(op[mi], acc); + } } - for (; dd + 7 < d; dd += 8) + for (; i < m; i++) { - __m256 acc = _mm256_loadu_ps(optr + dd); + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m256 acc = _mm256_loadu_ps(optr); for (int j = 0; j < n; j++) acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); - _mm256_storeu_ps(optr + dd, acc); + _mm256_storeu_ps(optr, acc); } + } - for (; dd < d; dd++) + for (; dd < d; dd++) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - float acc = optr[dd]; + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * d + dd]; - optr[dd] = acc; + { + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] += pptr[mi][j] * V[j * d + dd]; + } + } + for (; i < m; i++) + { + float* optr = O + i * d + 
dd; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + optr[0] += pptr[j] * V[j * d + dd]; } } } @@ -1474,7 +1504,7 @@ static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, static inline void softmax_tile_avx(float* P, const float* S, - float* m_vec, float* l_vec, int m, int n) + float* m_vec, float* l_vec, float* scale_out, int m, int n) { for (int i = 0; i < m; i++) { @@ -1490,6 +1520,7 @@ static inline void softmax_tile_avx(float* P, const float* S, m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; l_vec[i] *= scale_factor; __m256 vm_new = _mm256_set1_ps(m_new); @@ -1787,25 +1818,25 @@ template static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 4; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - op[mi] = O + (i + mi) * d; - pptr[mi] = P + (i + mi) * n; - } + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { for (int mi = 0; mi < M_BLOCK; mi++) { __m128 acc[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm_loadu_ps(op[mi] + dd + vi * 4); + acc[vi] = _mm_loadu_ps(op[mi] + vi * 4); for (int j = 0; j < n; j++) { @@ -1815,44 +1846,18 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, } for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm_storeu_ps(op[mi] + dd + vi * 4, acc[vi]); + _mm_storeu_ps(op[mi] + vi * 4, acc[vi]); } } - for (; dd + 3 < d; dd += 4) + for (; i < m; i++) { - for (int mi = 0; mi < M_BLOCK; mi++) - { - __m128 acc = 
_mm_loadu_ps(op[mi] + dd); - for (int j = 0; j < n; j++) - acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[mi][j]), _mm_loadu_ps(V + j * d + dd), acc); - _mm_storeu_ps(op[mi] + dd, acc); - } - } - - for (; dd < d; dd++) - { - for (int mi = 0; mi < M_BLOCK; mi++) - { - float acc = op[mi][dd]; - for (int j = 0; j < n; j++) - acc += pptr[mi][j] * V[j * d + dd]; - op[mi][dd] = acc; - } - } - } + float* optr = O + i * d + dd; + const float* pptr = P + i * n; - for (; i < m; i++) - { - float* optr = O + i * d; - const float* pptr = P + i * n; - - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { __m128 acc[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm_loadu_ps(optr + dd + vi * 4); + acc[vi] = _mm_loadu_ps(optr + vi * 4); for (int j = 0; j < n; j++) { @@ -1862,23 +1867,67 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, } for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm_storeu_ps(optr + dd + vi * 4, acc[vi]); + _mm_storeu_ps(optr + vi * 4, acc[vi]); + } + } + + for (; dd + 3 < d; dd += 4) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m128 acc = _mm_loadu_ps(op[mi]); + for (int j = 0; j < n; j++) + acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[mi][j]), _mm_loadu_ps(V + j * d + dd), acc); + _mm_storeu_ps(op[mi], acc); + } } - for (; dd + 3 < d; dd += 4) + for (; i < m; i++) { - __m128 acc = _mm_loadu_ps(optr + dd); + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m128 acc = _mm_loadu_ps(optr); for (int j = 0; j < n; j++) acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), _mm_loadu_ps(V + j * d + dd), acc); - _mm_storeu_ps(optr + dd, acc); + _mm_storeu_ps(optr, acc); } + } - for (; dd < d; dd++) + for (; dd < d; dd++) + { + int i = 0; + for (; i + M_BLOCK <= m; i += 
M_BLOCK) { - float acc = optr[dd]; + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * d + dd]; - optr[dd] = acc; + { + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] += pptr[mi][j] * V[j * d + dd]; + } + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + optr[0] += pptr[j] * V[j * d + dd]; } } } @@ -1940,7 +1989,7 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, static inline void softmax_tile_sse2(float* P, const float* S, - float* m_vec, float* l_vec, int m, int n) + float* m_vec, float* l_vec, float* scale_out, int m, int n) { for (int i = 0; i < m; i++) { @@ -1956,6 +2005,7 @@ static inline void softmax_tile_sse2(float* P, const float* S, m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; l_vec[i] *= scale_factor; __m128 vm_new = _mm_set1_ps(m_new); @@ -2160,52 +2210,12 @@ static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, int m, int n, int d) { - if (d == 128) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n); - return; -#elif __AVX__ - pv_gemm_avx<2, 128>(O, P, V, m, n); - return; -#elif __SSE2__ - pv_gemm_sse2<4, 128>(O, P, V, m, n); - return; -#endif - } - if (d == 64) - { -#if __AVX512F__ - pv_gemm_avx512<2, 64>(O, P, V, m, n); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n); - return; -#elif __SSE2__ - pv_gemm_sse2<4, 64>(O, P, V, m, n); - return; -#endif - } - if (d == 512) - { -#if __AVX512F__ - pv_gemm_avx512<2, 512>(O, P, V, m, n); - return; -#elif __AVX__ - pv_gemm_avx<2, 512>(O, P, V, m, n); - return; -#elif __SSE2__ - pv_gemm_sse2<4, 512>(O, P, V, m, n); - return; -#endif - } - #if __AVX512F__ 
- pv_gemm_avx512<2, 128>(O, P, V, m, n, d); + pv_gemm_avx512<4, 64>(O, P, V, m, n, d); #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + pv_gemm_avx<4, 32>(O, P, V, m, n, d); #elif __SSE2__ - pv_gemm_sse2<4, 4>(O, P, V, m, n, d); + pv_gemm_sse2<4, 16>(O, P, V, m, n, d); #else pv_gemm_scalar(O, P, V, m, n, d); #endif @@ -2238,16 +2248,16 @@ static inline void vec_zero_dispatch(float* x, int n) } static inline void softmax_tile_dispatch(float* P, const float* S, - float* m_vec, float* l_vec, int m, int n) + float* m_vec, float* l_vec, float* scale_out, int m, int n) { #if __AVX512F__ - softmax_tile_avx512(P, S, m_vec, l_vec, m, n); + softmax_tile_avx512(P, S, m_vec, l_vec, scale_out, m, n); #elif __AVX__ - softmax_tile_avx(P, S, m_vec, l_vec, m, n); + softmax_tile_avx(P, S, m_vec, l_vec, scale_out, m, n); #elif __SSE2__ - softmax_tile_sse2(P, S, m_vec, l_vec, m, n); + softmax_tile_sse2(P, S, m_vec, l_vec, scale_out, m, n); #else - softmax_tile_scalar(P, S, m_vec, l_vec, m, n); + softmax_tile_scalar(P, S, m_vec, l_vec, scale_out, m, n); #endif } @@ -2600,8 +2610,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } - Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N * BLOCK_M, num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat p_vec(BLOCK_N * BLOCK_M, num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); if (s_vec.empty() || p_vec.empty()) return -100; @@ -2626,8 +2636,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat key_head = key.channel(g); const Mat value_head = value.channel(g); - float* s_vec_ptr = s_vec.channel(get_omp_thread_num()).row(0); - float* p_vec_ptr = p_vec.channel(get_omp_thread_num()).row(0); + Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); + Mat p_vec_thread = p_vec.channel(get_omp_thread_num()); Mat 
m_state_tile = m_state.channel(idx); Mat l_state_tile = l_state.channel(idx); @@ -2658,32 +2668,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; - // Process one Q head at a time with a single-head scratch buffer, - // keeping K/V tile resident in cache across hq iterations. + // Phase A: All heads compute QK GEMM while K tile is hot in cache for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; const Mat query_head = query.channel(q); - Mat top_blob_head = top_blob.channel(q); + float* s_ptr = s_vec_thread.row(hq); - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - - Mat mask_head; - if (attn_mask) - { - const Mat& maskm = attn_mask_blob; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - } - - qk_gemm_dispatch(s_vec_ptr, + qk_gemm_dispatch(s_ptr, query_head.row(m_start), key_head.row(n_start), block_m, block_n, embed_dim, _scale); @@ -2691,10 +2683,22 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to // Apply attention mask if (attn_mask) { + Mat mask_head; + { + const Mat& maskm = attn_mask_blob; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + } for (int i = 0; i < block_m; i++) { const float* mptr = mask_head.row(m_start + i) + n_start; - float* sptr = s_vec_ptr + i * block_n; + float* sptr = s_ptr + i * block_n; int j = 0; #if __AVX512F__ for (; j + 15 < block_n; j += 16) @@ -2726,25 +2730,37 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } } + } + + // Phase B: Each head does softmax + O rescale + PV GEMM + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + Mat top_blob_head = top_blob.channel(q); + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* s_ptr = s_vec_thread.row(hq); + float* p_ptr = p_vec_thread.row(hq); float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; for (int i = 0; i < block_m; i++) { m_old[i] = m_vec[i]; } - softmax_tile_dispatch(p_vec_ptr, s_vec_ptr, m_vec, l_vec, block_m, block_n); + softmax_tile_dispatch(p_ptr, s_ptr, m_vec, l_vec, scale_factors, block_m, block_n); // Rescale O accumulator when max increases for (int i = 0; i < block_m; i++) { if (m_old[i] != m_vec[i]) { - float scale_factor = expf(m_old[i] - m_vec[i]); - vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factor, out_embed_dim); + vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factors[i], out_embed_dim); } } - pv_gemm_dispatch(top_blob_head.row(m_start), p_vec_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); + pv_gemm_dispatch(top_blob_head.row(m_start), p_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); } } From dc2de4b0a3af2eb3155d848a15342d83563085f3 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 10:31:46 +0800 Subject: [PATCH 12/53] Opt layout --- src/layer/x86/sdpa_x86.cpp | 294 +++++++++++++++++++++++++++++++++---- 1 file changed, 267 insertions(+), 27 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 
c950f52d7025..58f7eb2d0012 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -615,6 +615,232 @@ static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const fl } } +// Explicit specialization for D=128: 6x4 kernel to improve K-tile reuse +template<> +void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + const float* k2 = K + (j + 2) * 128; + const float* k3 = K + (j + 3) * 128; + + __m512 acc[6][4]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); + } + + for (int k = 0; k < 128; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + + for (int mi = 0; mi < 6; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + + __m512 acc[6][2]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 
0; k < 128; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 6; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 128; + + __m512 acc[6]; + for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < 128; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 6; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < 128; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + const float* kptr = K + j * 128; + for (int k = 
0; k < 128; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 128; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + const float* k2 = K + (j + 2) * 128; + const float* k3 = K + (j + 3) * 128; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < 128; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 128; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + + for (int k = 0; k < 128; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 128; + const float* kptr = K + j 
* 128; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < 128; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + template static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) @@ -2210,6 +2436,20 @@ static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, int m, int n, int d) { + if (d == 128) + { +#if __AVX512F__ + pv_gemm_avx512<2, 128>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 64>(O, P, V, m, n, d); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(O, P, V, m, n, d); + return; +#endif + } + #if __AVX512F__ pv_gemm_avx512<4, 64>(O, P, V, m, n, d); #elif __AVX__ @@ -2610,8 +2850,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } - Mat s_vec(BLOCK_N * BLOCK_M, num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - Mat p_vec(BLOCK_N * BLOCK_M, num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); if (s_vec.empty() || p_vec.empty()) return -100; @@ -2639,6 +2879,24 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); Mat p_vec_thread = p_vec.channel(get_omp_thread_num()); + // Pre-resolve mask pointers for all heads in this group + const float* mask_data[num_heads_per_group]; + int mask_stride[num_heads_per_group]; + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + mask_data[hq] = nullptr; + mask_stride[hq] = 0; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + Mat mh = (maskm.dims == 3 && maskm.c > 1) ? maskm.channel(q) + : (maskm.dims == 3 ? 
maskm.channel(0) : maskm); + mask_data[hq] = mh; + mask_stride[hq] = mh.w; + } + } + Mat m_state_tile = m_state.channel(idx); Mat l_state_tile = l_state.channel(idx); @@ -2668,12 +2926,15 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; - // Phase A: All heads compute QK GEMM while K tile is hot in cache + // Phase A + B fused per head: reuse compact s_vec/p_vec buffer for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; const Mat query_head = query.channel(q); - float* s_ptr = s_vec_thread.row(hq); + Mat top_blob_head = top_blob.channel(q); + + float* s_ptr = s_vec_thread.row(0); + float* p_ptr = p_vec_thread.row(0); qk_gemm_dispatch(s_ptr, query_head.row(m_start), @@ -2681,23 +2942,11 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to block_m, block_n, embed_dim, _scale); // Apply attention mask - if (attn_mask) + if (attn_mask && mask_data[hq]) { - Mat mask_head; - { - const Mat& maskm = attn_mask_blob; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - } for (int i = 0; i < block_m; i++) { - const float* mptr = mask_head.row(m_start + i) + n_start; + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; float* sptr = s_ptr + i * block_n; int j = 0; #if __AVX512F__ @@ -2730,18 +2979,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } } - } - - // Phase B: Each head does softmax + O rescale + PV GEMM - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - Mat top_blob_head = top_blob.channel(q); float* m_vec = m_state_tile.row(hq); float* l_vec = l_state_tile.row(hq); - float* s_ptr = s_vec_thread.row(hq); - float* p_ptr = p_vec_thread.row(hq); float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; From 9bfd7329952c81affab5e19494a9efef07d13e5b Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 10:49:36 +0800 Subject: [PATCH 13/53] opt for cache --- src/layer/x86/sdpa_x86.cpp | 542 ++++++++++++++++++++++--------------- 1 file changed, 328 insertions(+), 214 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 58f7eb2d0012..10e3af86723b 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1374,26 +1374,35 @@ static void qk_gemm_avx(float* S, const float* Q, const float* K, const float* k0 = K + (j + 0) * d; const float* k1 = K + (j + 1) * d; + __m256 acc[6][2]; for (int mi = 0; mi < 6; mi++) { - const float* qptr = Q + (i + mi) * d; - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } - int k = 0; - for (; k + 7 < d; k += 8) + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 6; mi++) { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, 
_mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); } + } - float sum0 = _mm256_reduce_add_ps(acc0); - float sum1 = _mm256_reduce_add_ps(acc1); + for (int mi = 0; mi < 6; mi++) + { + float sum0 = _mm256_reduce_add_ps(acc[mi][0]); + float sum1 = _mm256_reduce_add_ps(acc[mi][1]); for (; k < d; k++) { - float qv = qptr[k]; + float qv = Q[(i + mi) * d + k]; sum0 += qv * k0[k]; sum1 += qv * k1[k]; } @@ -1405,18 +1414,28 @@ static void qk_gemm_avx(float* S, const float* Q, const float* K, for (; j < n; j++) { + const float* kptr = K + j * d; + + __m256 acc[6]; for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) { - const float* qptr = Q + (i + mi) * d; - const float* kptr = K + j * d; - float sum = 0.f; - int k = 0; - __m256 vacc = _mm256_setzero_ps(); - for (; k + 7 < d; k += 8) - vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); - sum = _mm256_reduce_add_ps(vacc); + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum = _mm256_reduce_add_ps(acc[mi]); for (; k < d; k++) - sum += qptr[k] * kptr[k]; + sum += Q[(i + mi) * d + k] * kptr[k]; S[(i + mi) * n + j] = sum * scale; } } @@ -1487,35 +1506,53 @@ static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; + __m256 acc[6][2]; for (int mi = 0; mi < 6; mi++) { - const float* qptr = Q + (i + mi) * D; - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = 
_mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); - for (int k = 0; k < D; k += 8) + for (int mi = 0; mi < 6; mi++) { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); } + } - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + for (int mi = 0; mi < 6; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; } } for (; j < n; j++) { + const float* kptr = K + j * D; + + __m256 acc[6]; for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) { - const float* qptr = Q + (i + mi) * D; - const float* kptr = K + j * D; - __m256 vacc = _mm256_setzero_ps(); - for (int k = 0; k < D; k += 8) - vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); - S[(i + mi) * n + j] = _mm256_reduce_add_ps(vacc) * scale; + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } } + + for (int mi = 0; mi < 6; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; } } @@ -1559,116 +1596,110 @@ template static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 8; - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + float* 
op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m256 acc[M_BLOCK][VEC_PER_UNROLL]; for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm256_loadu_ps(op[mi] + dd + vi * 8); + + for (int j = 0; j < n; j++) { - __m256 acc[VEC_PER_UNROLL]; + __m256 vvec[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); + vvec[vi] = _mm256_loadu_ps(V + j * d + dd + vi * 8); - for (int j = 0; j < n; j++) + for (int mi = 0; mi < M_BLOCK; mi++) { __m256 pvec = _mm256_set1_ps(pptr[mi][j]); for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); + acc[mi][vi] = _mm256_comp_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); } + } + for (int mi = 0; mi < M_BLOCK; mi++) for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); - } + _mm256_storeu_ps(op[mi] + dd + vi * 8, acc[mi][vi]); } - for (; i < m; i++) + for (; dd + 7 < d; dd += 8) { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; - - __m256 acc[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_loadu_ps(optr + vi * 8); + __m256 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm256_loadu_ps(op[mi] + dd); for (int j = 0; j < n; j++) { - __m256 pvec = _mm256_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); + __m256 vvec = _mm256_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = 
_mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), vvec, acc[mi]); } - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm256_storeu_ps(optr + vi * 8, acc[vi]); + for (int mi = 0; mi < M_BLOCK; mi++) + _mm256_storeu_ps(op[mi] + dd, acc[mi]); } - } - for (; dd + 7 < d; dd += 8) - { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + for (; dd < d; dd++) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } - for (int mi = 0; mi < M_BLOCK; mi++) { - __m256 acc = _mm256_loadu_ps(op[mi]); + float acc = op[mi][dd]; for (int j = 0; j < n; j++) - acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), _mm256_loadu_ps(V + j * d + dd), acc); - _mm256_storeu_ps(op[mi], acc); + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; } } - - for (; i < m; i++) - { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; - __m256 acc = _mm256_loadu_ps(optr); - for (int j = 0; j < n; j++) - acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); - _mm256_storeu_ps(optr, acc); - } } - for (; dd < d; dd++) + for (; i < m; i++) { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } + __m256 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_loadu_ps(optr + dd + vi * 8); + for (int j = 0; j < n; j++) { - for (int mi = 0; mi < M_BLOCK; mi++) - op[mi][0] += pptr[mi][j] * V[j * d + dd]; + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + 
_mm256_storeu_ps(optr + dd + vi * 8, acc[vi]); } - for (; i < m; i++) + + for (; dd + 7 < d; dd += 8) { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; + __m256 acc = _mm256_loadu_ps(optr + dd); for (int j = 0; j < n; j++) - optr[0] += pptr[j] * V[j * d + dd]; + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); + _mm256_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; } } } @@ -1937,25 +1968,27 @@ static void qk_gemm_sse2(float* S, const float* Q, const float* K, { const float* kptr = K + j * d; + __m128 acc[4]; for (int mi = 0; mi < 4; mi++) - { - const float* qptr = Q + (i + mi) * d; - __m128 acc0 = _mm_setzero_ps(); + acc[mi] = _mm_setzero_ps(); - int k = 0; - for (; k + 3 < d; k += 4) + int k = 0; + for (; k + 3 < d; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) { - acc0 = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), acc0); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); } + } - float sum0 = _mm_reduce_add_ps(acc0); - + for (int mi = 0; mi < 4; mi++) + { + float sum = _mm_reduce_add_ps(acc[mi]); for (; k < d; k++) - { - sum0 += qptr[k] * kptr[k]; - } - - S[(i + mi) * n + j] = sum0 * scale; + sum += Q[(i + mi) * d + k] * kptr[k]; + S[(i + mi) * n + j] = sum * scale; } } } @@ -2044,116 +2077,110 @@ template static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 4; - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < 
M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m128 acc[M_BLOCK][VEC_PER_UNROLL]; for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm_loadu_ps(op[mi] + dd + vi * 4); + + for (int j = 0; j < n; j++) { - __m128 acc[VEC_PER_UNROLL]; + __m128 vvec[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm_loadu_ps(op[mi] + vi * 4); + vvec[vi] = _mm_loadu_ps(V + j * d + dd + vi * 4); - for (int j = 0; j < n; j++) + for (int mi = 0; mi < M_BLOCK; mi++) { __m128 pvec = _mm_set1_ps(pptr[mi][j]); for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * d + dd + vi * 4), acc[vi]); + acc[mi][vi] = _mm_comp_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); } + } + for (int mi = 0; mi < M_BLOCK; mi++) for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm_storeu_ps(op[mi] + vi * 4, acc[vi]); - } + _mm_storeu_ps(op[mi] + dd + vi * 4, acc[mi][vi]); } - for (; i < m; i++) + for (; dd + 3 < d; dd += 4) { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; - - __m128 acc[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm_loadu_ps(optr + vi * 4); + __m128 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm_loadu_ps(op[mi] + dd); for (int j = 0; j < n; j++) { - __m128 pvec = _mm_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * d + dd + vi * 4), acc[vi]); + __m128 vvec = _mm_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[mi][j]), vvec, acc[mi]); } - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm_storeu_ps(optr + vi * 4, acc[vi]); + for (int mi = 0; mi < M_BLOCK; mi++) + _mm_storeu_ps(op[mi] + dd, acc[mi]); } - } - for (; dd 
+ 3 < d; dd += 4) - { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + for (; dd < d; dd++) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; for (int mi = 0; mi < M_BLOCK; mi++) { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } - - for (int mi = 0; mi < M_BLOCK; mi++) - { - __m128 acc = _mm_loadu_ps(op[mi]); + float acc = op[mi][dd]; for (int j = 0; j < n; j++) - acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[mi][j]), _mm_loadu_ps(V + j * d + dd), acc); - _mm_storeu_ps(op[mi], acc); + acc += pptr[mi][j] * V[j * d + dd]; + op[mi][dd] = acc; } } - - for (; i < m; i++) - { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; - __m128 acc = _mm_loadu_ps(optr); - for (int j = 0; j < n; j++) - acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), _mm_loadu_ps(V + j * d + dd), acc); - _mm_storeu_ps(optr, acc); - } } - for (; dd < d; dd++) + for (; i < m; i++) { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + float* optr = O + i * d; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } + __m128 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_loadu_ps(optr + dd + vi * 4); + for (int j = 0; j < n; j++) { - for (int mi = 0; mi < M_BLOCK; mi++) - op[mi][0] += pptr[mi][j] * V[j * d + dd]; + __m128 pvec = _mm_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm_comp_fmadd_ps(pvec, _mm_loadu_ps(V + j * d + dd + vi * 4), acc[vi]); } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm_storeu_ps(optr + dd + vi * 4, acc[vi]); } - for (; i < m; i++) + + for (; dd + 3 < d; dd += 4) { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; + __m128 acc = _mm_loadu_ps(optr + dd); for (int j = 0; j < n; j++) - optr[0] += pptr[j] * V[j * d + dd]; + acc = 
_mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), _mm_loadu_ps(V + j * d + dd), acc); + _mm_storeu_ps(optr + dd, acc); + } + + for (; dd < d; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; } } } @@ -2419,6 +2446,71 @@ static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, #elif __SSE2__ qk_gemm_specialized_sse2<512>(S, Q, K, m, n, scale); return; +#endif + } + if (d == 256) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<256>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 32) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<32>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 80) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<80>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 96) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<96>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 160) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<160>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(S, Q, K, m, n, scale); + return; #endif } @@ -2442,10 +2534,36 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<2, 128>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 64>(O, P, 
V, m, n, d); + pv_gemm_avx<2, 32>(O, P, V, m, n, d); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + return; +#endif + } + if (d == 64) + { +#if __AVX512F__ + pv_gemm_avx512<4, 64>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 32>(O, P, V, m, n, d); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + return; +#endif + } + if (d == 256) + { +#if __AVX512F__ + pv_gemm_avx512<2, 256>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 32>(O, P, V, m, n, d); return; #elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, d); + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); return; #endif } @@ -2453,9 +2571,9 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, #if __AVX512F__ pv_gemm_avx512<4, 64>(O, P, V, m, n, d); #elif __AVX__ - pv_gemm_avx<4, 32>(O, P, V, m, n, d); + pv_gemm_avx<2, 32>(O, P, V, m, n, d); #elif __SSE2__ - pv_gemm_sse2<4, 16>(O, P, V, m, n, d); + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); #else pv_gemm_scalar(O, P, V, m, n, d); #endif @@ -2851,9 +2969,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat o_accum(out_embed_dim, BLOCK_M, num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - if (s_vec.empty() || p_vec.empty()) + if (s_vec.empty() || o_accum.empty()) return -100; int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; @@ -2877,7 +2995,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const Mat value_head = value.channel(g); Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); - Mat p_vec_thread = p_vec.channel(get_omp_thread_num()); + Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); // Pre-resolve mask pointers for all heads in this group const float* mask_data[num_heads_per_group]; @@ -2903,9 +3021,6 @@ int SDPA_x86::forward(const 
std::vector& bottom_blobs, std::vector& to // Initialize softmax state and zero output accumulator for all Q heads in this group for (int hq = 0; hq < num_heads_per_group; hq++) { - int q = g * num_heads_per_group + hq; - Mat top_blob_head = top_blob.channel(q); - float* m_vec = m_state_tile.row(hq); float* l_vec = l_state_tile.row(hq); for (int i = 0; i < block_m; i++) @@ -2914,10 +3029,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to l_vec[i] = 0.f; } - for (int i = 0; i < block_m; i++) - { - vec_zero_dispatch(top_blob_head.row(m_start + i), out_embed_dim); - } + float* o_ptr = o_accum_thread.channel(hq).row(0); + vec_zero_dispatch(o_ptr, out_embed_dim * BLOCK_M); } // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group @@ -2934,7 +3047,6 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat top_blob_head = top_blob.channel(q); float* s_ptr = s_vec_thread.row(0); - float* p_ptr = p_vec_thread.row(0); qk_gemm_dispatch(s_ptr, query_head.row(m_start), @@ -2982,6 +3094,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* m_vec = m_state_tile.row(hq); float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.channel(hq).row(0); float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; @@ -2989,27 +3102,28 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { m_old[i] = m_vec[i]; } - softmax_tile_dispatch(p_ptr, s_ptr, m_vec, l_vec, scale_factors, block_m, block_n); + softmax_tile_dispatch(s_ptr, s_ptr, m_vec, l_vec, scale_factors, block_m, block_n); // Rescale O accumulator when max increases for (int i = 0; i < block_m; i++) { if (m_old[i] != m_vec[i]) { - vec_scale_dispatch(top_blob_head.row(m_start + i), scale_factors[i], out_embed_dim); + vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } } - pv_gemm_dispatch(top_blob_head.row(m_start), p_ptr, value_head.row(n_start), block_m, block_n, 
out_embed_dim); + pv_gemm_dispatch(o_ptr, s_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); } } - // Normalize all Q heads for this M tile + // Normalize all Q heads for this M tile and write back to top_blob for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; Mat top_blob_head = top_blob.channel(q); float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.channel(hq).row(0); for (int i = 0; i < block_m; i++) { @@ -3020,30 +3134,30 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to __m512 vinv_l = _mm512_set1_ps(inv_l); for (; k + 15 < out_embed_dim; k += 16) { - _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(outptr + k), vinv_l)); + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); } if (k < out_embed_dim) { __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); - _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, outptr + k), vinv_l)); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, o_ptr + i * out_embed_dim + k), vinv_l)); k = out_embed_dim; } #elif __AVX__ __m256 vinv_l = _mm256_set1_ps(inv_l); for (; k + 7 < out_embed_dim; k += 8) { - _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(outptr + k), vinv_l)); + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); } #elif __SSE2__ __m128 vinv_l = _mm_set1_ps(inv_l); for (; k + 3 < out_embed_dim; k += 4) { - _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(outptr + k), vinv_l)); + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); } #endif for (; k < out_embed_dim; k++) { - outptr[k] *= inv_l; + outptr[k] = o_ptr[i * out_embed_dim + k] * inv_l; } } } From 22b6dceb78e632f49bc22291a071253f27de31e4 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 26 Apr 2026 23:53:23 +0800 
Subject: [PATCH 14/53] opt for lagre dim --- src/layer/x86/sdpa_x86.cpp | 3764 +++++++++++++++++++++++++++--------- 1 file changed, 2876 insertions(+), 888 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 10e3af86723b..4adc6ab940ba 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -841,231 +841,2031 @@ void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, } } - -template -static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) +template<> +void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) { - const int VEC_PER_UNROLL = D_UNROLL / 16; - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + int i = 0; + for (; i + 4 <= m; i += 4) { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + int j = 0; + for (; j + 4 <= n; j += 4) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; + + __m512 acc[4][4]; + for (int mi = 0; mi < 4; mi++) { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); } - __m512 acc[M_BLOCK][VEC_PER_UNROLL]; - for (int mi = 0; mi < M_BLOCK; mi++) - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[mi][vi] = _mm512_loadu_ps(op[mi] + vi * 16); - - for (int j = 0; j < n; j++) + for (int k = 0; k < 1024; k += 16) { - __m512 vvec[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); - for 
(int mi = 0; mi < M_BLOCK; mi++) + for (int mi = 0; mi < 4; mi++) { - __m512 pvec = _mm512_set1_ps(pptr[mi][j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[mi][vi] = _mm512_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); } } - for (int mi = 0; mi < M_BLOCK; mi++) - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm512_storeu_ps(op[mi] + vi * 16, acc[mi][vi]); + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } } - for (; i < m; i++) + for (; j + 2 <= n; j += 2) { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; - __m512 acc[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm512_loadu_ps(optr + vi * 16); + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } - for (int j = 0; j < n; j++) + for (int k = 0; k < 1024; k += 16) { - __m512 pvec = _mm512_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } } - for (int vi = 0; vi < 
VEC_PER_UNROLL; vi++) - _mm512_storeu_ps(optr + vi * 16, acc[vi]); + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } } - } - for (; dd + 15 < d; dd += 16) - { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + for (; j < n; j++) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; - } + const float* kptr = K + j * 1024; - __m512 acc[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_loadu_ps(op[mi]); + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); - for (int j = 0; j < n; j++) + for (int k = 0; k < 1024; k += 16) { - __m512 vvec = _mm512_loadu_ps(V + j * d + dd); - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } } - for (int mi = 0; mi < M_BLOCK; mi++) - _mm512_storeu_ps(op[mi], acc[mi]); - } - - for (; i < m; i++) - { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; - __m512 acc = _mm512_loadu_ps(optr); - for (int j = 0; j < n; j++) - acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * d + dd), acc); - _mm512_storeu_ps(optr, acc); + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; } } - for (; dd < d; dd++) + for (; i + 2 <= m; i += 2) { - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + int j = 0; + for (; j + 4 <= n; j += 4) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 
2) * 1024; + const float* k3 = K + (j + 3) * 1024; + + __m512 acc[2][4]; + for (int mi = 0; mi < 2; mi++) { - op[mi] = O + (i + mi) * d + dd; - pptr[mi] = P + (i + mi) * n; + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); } - for (int j = 0; j < n; j++) + + for (int k = 0; k < 1024; k += 16) { - for (int mi = 0; mi < M_BLOCK; mi++) - op[mi][0] += pptr[mi][j] * V[j * d + dd]; - } - } - for (; i < m; i++) - { - float* optr = O + i * d + dd; - const float* pptr = P + i * n; - for (int j = 0; j < n; j++) - optr[0] += pptr[j] * V[j * d + dd]; - } - } -} + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } -template -static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) -{ - const int VEC_PER_D = D / 16; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) - { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * D; - pptr[mi] = P + (i + mi) * n; + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } } - int dd = 0; - for (; dd + 127 < D; dd += 128) + for (; j + 2 <= n; j += 2) { - __m512 acc[M_BLOCK][8]; - for (int mi = 0; mi < M_BLOCK; mi++) + const float* k0 = K + (j + 0) * 
1024; + const float* k1 = K + (j + 1) * 1024; + + __m512 acc[2][2]; + for (int mi = 0; mi < 2; mi++) { - acc[mi][0] = _mm512_loadu_ps(op[mi] + dd + 0 * 16); - acc[mi][1] = _mm512_loadu_ps(op[mi] + dd + 1 * 16); - acc[mi][2] = _mm512_loadu_ps(op[mi] + dd + 2 * 16); - acc[mi][3] = _mm512_loadu_ps(op[mi] + dd + 3 * 16); - acc[mi][4] = _mm512_loadu_ps(op[mi] + dd + 4 * 16); - acc[mi][5] = _mm512_loadu_ps(op[mi] + dd + 5 * 16); - acc[mi][6] = _mm512_loadu_ps(op[mi] + dd + 6 * 16); - acc[mi][7] = _mm512_loadu_ps(op[mi] + dd + 7 * 16); + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); } - for (int j = 0; j < n; j++) + for (int k = 0; k < 1024; k += 16) { - __m512 v0 = _mm512_loadu_ps(V + j * D + dd + 0 * 16); - __m512 v1 = _mm512_loadu_ps(V + j * D + dd + 1 * 16); - __m512 v2 = _mm512_loadu_ps(V + j * D + dd + 2 * 16); - __m512 v3 = _mm512_loadu_ps(V + j * D + dd + 3 * 16); - __m512 v4 = _mm512_loadu_ps(V + j * D + dd + 4 * 16); - __m512 v5 = _mm512_loadu_ps(V + j * D + dd + 5 * 16); - __m512 v6 = _mm512_loadu_ps(V + j * D + dd + 6 * 16); - __m512 v7 = _mm512_loadu_ps(V + j * D + dd + 7 * 16); + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); - for (int mi = 0; mi < M_BLOCK; mi++) + for (int mi = 0; mi < 2; mi++) { - __m512 pvec = _mm512_set1_ps(pptr[mi][j]); - acc[mi][0] = _mm512_fmadd_ps(pvec, v0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(pvec, v1, acc[mi][1]); - acc[mi][2] = _mm512_fmadd_ps(pvec, v2, acc[mi][2]); - acc[mi][3] = _mm512_fmadd_ps(pvec, v3, acc[mi][3]); - acc[mi][4] = _mm512_fmadd_ps(pvec, v4, acc[mi][4]); - acc[mi][5] = _mm512_fmadd_ps(pvec, v5, acc[mi][5]); - acc[mi][6] = _mm512_fmadd_ps(pvec, v6, acc[mi][6]); - acc[mi][7] = _mm512_fmadd_ps(pvec, v7, acc[mi][7]); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } } - for (int mi = 0; mi < M_BLOCK; mi++) + for (int mi = 0; mi < 2; 
mi++) { - _mm512_storeu_ps(op[mi] + dd + 0 * 16, acc[mi][0]); - _mm512_storeu_ps(op[mi] + dd + 1 * 16, acc[mi][1]); - _mm512_storeu_ps(op[mi] + dd + 2 * 16, acc[mi][2]); - _mm512_storeu_ps(op[mi] + dd + 3 * 16, acc[mi][3]); - _mm512_storeu_ps(op[mi] + dd + 4 * 16, acc[mi][4]); - _mm512_storeu_ps(op[mi] + dd + 5 * 16, acc[mi][5]); - _mm512_storeu_ps(op[mi] + dd + 6 * 16, acc[mi][6]); - _mm512_storeu_ps(op[mi] + dd + 7 * 16, acc[mi][7]); + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; } } - for (; dd + 15 < D; dd += 16) + for (; j < n; j++) { - __m512 acc[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_loadu_ps(op[mi] + dd); + const float* kptr = K + j * 1024; - for (int j = 0; j < n; j++) + __m512 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < 1024; k += 16) { - __m512 vvec = _mm512_loadu_ps(V + j * D + dd); - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } } + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < 1024; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = 
_mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + + for (int k = 0; k < 1024; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 1024; + const float* kptr = K + j * 1024; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < 1024; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + const float* k2 = K + (j + 2) * 2048; + const float* k3 = K + (j + 3) * 2048; + + __m512 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); + } + + for (int k = 0; k < 2048; k += 16) + { + __m512 kv0 = 
_mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + + __m512 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < 2048; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 2048; + + __m512 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < 2048; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for 
(int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 2048; + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + const float* k2 = K + (j + 2) * 2048; + const float* k3 = K + (j + 3) * 2048; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < 2048; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 2048; + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + + for (int k = 0; k < 2048; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 2048; + const float* kptr = K + j * 2048; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < 2048; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) 
* scale; + } + } +} + +template<> +void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + const float* k2 = K + (j + 2) * 4096; + const float* k3 = K + (j + 3) * 4096; + + __m512 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); + } + + for (int k = 0; k < 4096; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + + __m512 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + for (int k = 0; k < 4096; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, 
acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 4096; + + __m512 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < 4096; k += 16) + { + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 4096; + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + const float* k2 = K + (j + 2) * 4096; + const float* k3 = K + (j + 3) * 4096; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + for (int k = 0; k < 4096; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 4096; + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = 
_mm512_setzero_ps(); + + for (int k = 0; k < 4096; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 4096; + const float* kptr = K + j * 4096; + __m512 vacc = _mm512_setzero_ps(); + for (int k = 0; k < 4096; k += 16) + vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; + } + } +} + + +template +static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 16; + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m512 acc[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm512_loadu_ps(op[mi] + vi * 16); + + for (int j = 0; j < n; j++) + { + __m512 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm512_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(op[mi] + vi * 16, acc[mi][vi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + + __m512 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; 
vi++) + acc[vi] = _mm512_loadu_ps(optr + vi * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm512_storeu_ps(optr + vi * 16, acc[vi]); + } + } + + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc = _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * d + dd), acc); + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + for (int j = 0; j < n; j++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] += pptr[mi][j] * V[j * d + dd]; + } + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + optr[0] += pptr[j] * V[j * d + dd]; + } + } +} + + +template +static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 16; + int i = 0; + for (; i 
+ M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + 127 < D; dd += 128) + { + __m512 acc[M_BLOCK][8]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + dd + 0 * 16); + acc[mi][1] = _mm512_loadu_ps(op[mi] + dd + 1 * 16); + acc[mi][2] = _mm512_loadu_ps(op[mi] + dd + 2 * 16); + acc[mi][3] = _mm512_loadu_ps(op[mi] + dd + 3 * 16); + acc[mi][4] = _mm512_loadu_ps(op[mi] + dd + 4 * 16); + acc[mi][5] = _mm512_loadu_ps(op[mi] + dd + 5 * 16); + acc[mi][6] = _mm512_loadu_ps(op[mi] + dd + 6 * 16); + acc[mi][7] = _mm512_loadu_ps(op[mi] + dd + 7 * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 v0 = _mm512_loadu_ps(V + j * D + dd + 0 * 16); + __m512 v1 = _mm512_loadu_ps(V + j * D + dd + 1 * 16); + __m512 v2 = _mm512_loadu_ps(V + j * D + dd + 2 * 16); + __m512 v3 = _mm512_loadu_ps(V + j * D + dd + 3 * 16); + __m512 v4 = _mm512_loadu_ps(V + j * D + dd + 4 * 16); + __m512 v5 = _mm512_loadu_ps(V + j * D + dd + 5 * 16); + __m512 v6 = _mm512_loadu_ps(V + j * D + dd + 6 * 16); + __m512 v7 = _mm512_loadu_ps(V + j * D + dd + 7 * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + acc[mi][0] = _mm512_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(pvec, v3, acc[mi][3]); + acc[mi][4] = _mm512_fmadd_ps(pvec, v4, acc[mi][4]); + acc[mi][5] = _mm512_fmadd_ps(pvec, v5, acc[mi][5]); + acc[mi][6] = _mm512_fmadd_ps(pvec, v6, acc[mi][6]); + acc[mi][7] = _mm512_fmadd_ps(pvec, v7, acc[mi][7]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + _mm512_storeu_ps(op[mi] + dd + 0 * 16, acc[mi][0]); + _mm512_storeu_ps(op[mi] + dd + 1 * 16, acc[mi][1]); + _mm512_storeu_ps(op[mi] + dd + 2 * 16, acc[mi][2]); + _mm512_storeu_ps(op[mi] 
+ dd + 3 * 16, acc[mi][3]); + _mm512_storeu_ps(op[mi] + dd + 4 * 16, acc[mi][4]); + _mm512_storeu_ps(op[mi] + dd + 5 * 16, acc[mi][5]); + _mm512_storeu_ps(op[mi] + dd + 6 * 16, acc[mi][6]); + _mm512_storeu_ps(op[mi] + dd + 7 * 16, acc[mi][7]); + } + } + + for (; dd + 15 < D; dd += 16) + { + __m512 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m512 vvec = _mm512_loadu_ps(V + j * D + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm512_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < D; dd++) + { + for (int mi = 0; mi < M_BLOCK; mi++) + { + float acc = op[mi][dd]; + for (int j = 0; j < n; j++) + acc += pptr[mi][j] * V[j * D + dd]; + op[mi][dd] = acc; + } + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + int dd = 0; + for (; dd + 127 < D; dd += 128) + { + __m512 acc0 = _mm512_loadu_ps(optr + dd + 0 * 16); + __m512 acc1 = _mm512_loadu_ps(optr + dd + 1 * 16); + __m512 acc2 = _mm512_loadu_ps(optr + dd + 2 * 16); + __m512 acc3 = _mm512_loadu_ps(optr + dd + 3 * 16); + __m512 acc4 = _mm512_loadu_ps(optr + dd + 4 * 16); + __m512 acc5 = _mm512_loadu_ps(optr + dd + 5 * 16); + __m512 acc6 = _mm512_loadu_ps(optr + dd + 6 * 16); + __m512 acc7 = _mm512_loadu_ps(optr + dd + 7 * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + acc0 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 0 * 16), acc0); + acc1 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 1 * 16), acc1); + acc2 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 2 * 16), acc2); + acc3 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 3 * 16), acc3); + acc4 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 4 * 16), acc4); + acc5 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 5 * 16), 
acc5); + acc6 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 6 * 16), acc6); + acc7 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 7 * 16), acc7); + } + + _mm512_storeu_ps(optr + dd + 0 * 16, acc0); + _mm512_storeu_ps(optr + dd + 1 * 16, acc1); + _mm512_storeu_ps(optr + dd + 2 * 16, acc2); + _mm512_storeu_ps(optr + dd + 3 * 16, acc3); + _mm512_storeu_ps(optr + dd + 4 * 16, acc4); + _mm512_storeu_ps(optr + dd + 5 * 16, acc5); + _mm512_storeu_ps(optr + dd + 6 * 16, acc6); + _mm512_storeu_ps(optr + dd + 7 * 16, acc7); + } + + for (; dd + 15 < D; dd += 16) + { + __m512 acc = _mm512_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * D + dd), acc); + _mm512_storeu_ps(optr + dd, acc); + } + + for (; dd < D; dd++) + { + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * D + dd]; + optr[dd] = acc; + } + } +} + + +static inline void softmax_tile_avx512(float* P, const float* S, + float* m_vec, float* l_vec, float* scale_out, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); + if (j < n) + { + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 tail = _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, sptr + j); + vmax = _mm512_max_ps(vmax, tail); + } + float m_new = _mm512_comp_reduce_max_ps(vmax); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + l_vec[i] *= scale_factor; + + __m512 vm_new = _mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < n; j += 16) + { + __m512 svec = _mm512_loadu_ps(sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_storeu_ps(pptr + j, evec); + vsum = _mm512_add_ps(vsum, evec); + } + if (j < n) + { + __mmask16 mask = (__mmask16)((1u 
<< (n - j)) - 1); + __m512 svec = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_mask_storeu_ps(pptr + j, mask, evec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, evec); + } + float l_add = _mm512_comp_reduce_add_ps(vsum); + l_vec[i] += l_add; + m_vec[i] = m_new; + } +} + +static inline void vec_scale_avx512(float* x, float s, int n) +{ + __m512 vscale = _mm512_set1_ps(s); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale)); + } +} + +static inline void vec_zero_avx512(float* x, int n) +{ + __m512 zero = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, zero); + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, zero); + } +} + +static inline void decode_qk_dot_avx512(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + for (; j + 3 < block_n; j += 4) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + const float* k2 = K + (n_start + j + 2) * d; + const float* k3 = K + (n_start + j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); 
+ acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < block_n; j++) + { + __m512 acc = _mm512_setzero_ps(); + int k = 0; + for (; k + 15 < d; k += 16) + acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), _mm512_loadu_ps(K + (n_start + j) * d + k), acc); + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), _mm512_maskz_loadu_ps(mask_d, K + (n_start + j) * d + k), acc); + } + s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + } +} + +static inline void decode_pv_gemv_avx512(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + __m512 pvec = _mm512_set1_ps(s[j]); + int k = 0; + for (; k + 15 < out_d; k += 16) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vval, oval)); + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); + _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec, vval, oval)); + } + } +} + +static void sdpa_decode_avx512(float* out, const float* q, + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; + __attribute__((aligned(64))) float s[BLOCK_N]; + + vec_zero_avx512(out, out_d); + + 
float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < n; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n - n_start); + + decode_qk_dot_avx512(s, q, K, n_start, block_n, d, scale); + + if (mask) + { + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); + } + } + + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + } + float tile_m = _mm512_comp_reduce_max_ps(vmax); + + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_avx512(out, scale_factor, out_d); + } + + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + l += _mm512_comp_reduce_add_ps(vsum); + + decode_pv_gemv_avx512(out, s, V, n_start, block_n, out_d); + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale_avx512(out, inv_l, out_d); +} + +#endif // __AVX512F__ + +#if __AVX__ + +static void qk_gemm_avx(float* S, const 
float* Q, const float* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m256 acc[6][2]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum0 = _mm256_reduce_add_ps(acc[mi][0]); + float sum1 = _mm256_reduce_add_ps(acc[mi][1]); + + for (; k < d; k++) + { + float qv = Q[(i + mi) * d + k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + } + + S[(i + mi) * n + j + 0] = sum0 * scale; + S[(i + mi) * n + j + 1] = sum1 * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * d; + + __m256 acc[6]; + for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + float sum = _mm256_reduce_add_ps(acc[mi]); + for (; k < d; k++) + sum += Q[(i + mi) * d + k] * kptr[k]; + S[(i + mi) * n + j] = sum * scale; + } + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * d; + const float* k0 = K + (j + 0) * d; + const float* k1 = K + (j + 1) * d; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = 
_mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + float qv = qptr[k]; + sum0 += qv * k0[k]; + sum1 += qv * k1[k]; + } + + S[i * n + j + 0] = sum0 * scale; + S[i * n + j + 1] = sum1 * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m256 vacc = _mm256_setzero_ps(); + for (; k + 7 < d; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + sum = _mm256_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} + + +template +static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 6 <= m; i += 6) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc[6][2]; + for (int mi = 0; mi < 6; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 6; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * D; + + __m256 acc[6]; + for (int mi = 0; mi < 6; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int 
mi = 0; mi < 6; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 6; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 1 < n; j += 2) + { + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < D; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < D; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; + + __m256 acc[4][4]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 1024; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m256 
qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + + __m256 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 1024; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 1024; + + __m256 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < 1024; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const 
float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; + + __m256 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 1024; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + + __m256 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 1024; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) 
* scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 1024; + + __m256 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < 1024; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int k = 0; k < 1024; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + acc2 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k2 + k), acc2); + acc3 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm256_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm256_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < 1024; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 
0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 1024; + const float* kptr = K + j * 1024; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < 1024; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + const float* k2 = K + (j + 2) * 2048; + const float* k3 = K + (j + 3) * 2048; + + __m256 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 2048; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 
2048; + + __m256 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 2048; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 2048; + + __m256 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < 2048; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 2048; + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + const float* k2 = K + (j + 2) * 2048; + const float* k3 = K + (j + 3) * 2048; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int k = 0; k < 2048; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + acc2 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k2 + k), acc2); + acc3 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = 
_mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm256_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm256_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 2048; + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < 2048; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 2048; + const float* kptr = K + j * 2048; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < 2048; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + +template<> +void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + const float* k2 = K + (j + 2) * 4096; + const float* k3 = K + (j + 3) * 4096; + + __m256 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 4096; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 4096 + k); + 
acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; + } + } + + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + + __m256 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + } + + for (int k = 0; k < 4096; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const float* kptr = K + j * 4096; + + __m256 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_setzero_ps(); + + for (int k = 0; k < 4096; k += 8) + { + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 4096; + const float* k0 = K + (j + 0) * 
4096; + const float* k1 = K + (j + 1) * 4096; + const float* k2 = K + (j + 2) * 4096; + const float* k3 = K + (j + 3) * 4096; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + __m256 acc2 = _mm256_setzero_ps(); + __m256 acc3 = _mm256_setzero_ps(); + + for (int k = 0; k < 4096; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + acc2 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k2 + k), acc2); + acc3 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm256_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm256_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 4096; + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + for (int k = 0; k < 4096; k += 8) + { + __m256 qvec = _mm256_loadu_ps(qptr + k); + acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; + } + + for (; j < n; j++) + { + const float* qptr = Q + i * 4096; + const float* kptr = K + j * 4096; + __m256 vacc = _mm256_setzero_ps(); + for (int k = 0; k < 4096; k += 8) + vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + } + } +} + + +template +static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) +{ + const int VEC_PER_UNROLL = D_UNROLL / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* 
op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * d; + pptr[mi] = P + (i + mi) * n; + } + + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + __m256 acc[M_BLOCK][VEC_PER_UNROLL]; for (int mi = 0; mi < M_BLOCK; mi++) - _mm512_storeu_ps(op[mi] + dd, acc[mi]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm256_loadu_ps(op[mi] + dd + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm256_loadu_ps(V + j * d + dd + vi * 8); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[mi][vi] = _mm256_comp_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm256_storeu_ps(op[mi] + dd + vi * 8, acc[mi][vi]); } - for (; dd < D; dd++) + for (; dd + 7 < d; dd += 8) + { + __m256 acc[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm256_loadu_ps(op[mi] + dd); + + for (int j = 0; j < n; j++) + { + __m256 vvec = _mm256_loadu_ps(V + j * d + dd); + for (int mi = 0; mi < M_BLOCK; mi++) + acc[mi] = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + + for (int mi = 0; mi < M_BLOCK; mi++) + _mm256_storeu_ps(op[mi] + dd, acc[mi]); + } + + for (; dd < d; dd++) { for (int mi = 0; mi < M_BLOCK; mi++) { float acc = op[mi][dd]; for (int j = 0; j < n; j++) - acc += pptr[mi][j] * V[j * D + dd]; + acc += pptr[mi][j] * V[j * d + dd]; op[mi][dd] = acc; } } @@ -1073,64 +2873,102 @@ static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int for (; i < m; i++) { - float* optr = O + i * D; + float* optr = O + i * d; const float* pptr = P + i * n; int dd = 0; - for (; dd + 127 < D; dd += 128) + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { - __m512 acc0 = _mm512_loadu_ps(optr + dd + 0 * 16); - __m512 
acc1 = _mm512_loadu_ps(optr + dd + 1 * 16); - __m512 acc2 = _mm512_loadu_ps(optr + dd + 2 * 16); - __m512 acc3 = _mm512_loadu_ps(optr + dd + 3 * 16); - __m512 acc4 = _mm512_loadu_ps(optr + dd + 4 * 16); - __m512 acc5 = _mm512_loadu_ps(optr + dd + 5 * 16); - __m512 acc6 = _mm512_loadu_ps(optr + dd + 6 * 16); - __m512 acc7 = _mm512_loadu_ps(optr + dd + 7 * 16); + __m256 acc[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_loadu_ps(optr + dd + vi * 8); for (int j = 0; j < n; j++) { - __m512 pvec = _mm512_set1_ps(pptr[j]); - acc0 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 0 * 16), acc0); - acc1 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 1 * 16), acc1); - acc2 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 2 * 16), acc2); - acc3 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 3 * 16), acc3); - acc4 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 4 * 16), acc4); - acc5 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 5 * 16), acc5); - acc6 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 6 * 16), acc6); - acc7 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 7 * 16), acc7); + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); } - _mm512_storeu_ps(optr + dd + 0 * 16, acc0); - _mm512_storeu_ps(optr + dd + 1 * 16, acc1); - _mm512_storeu_ps(optr + dd + 2 * 16, acc2); - _mm512_storeu_ps(optr + dd + 3 * 16, acc3); - _mm512_storeu_ps(optr + dd + 4 * 16, acc4); - _mm512_storeu_ps(optr + dd + 5 * 16, acc5); - _mm512_storeu_ps(optr + dd + 6 * 16, acc6); - _mm512_storeu_ps(optr + dd + 7 * 16, acc7); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + _mm256_storeu_ps(optr + dd + vi * 8, acc[vi]); } - for (; dd + 15 < D; dd += 16) + for (; dd + 7 < d; dd += 8) { - __m512 acc = _mm512_loadu_ps(optr + dd); + __m256 acc = _mm256_loadu_ps(optr + dd); for 
(int j = 0; j < n; j++) - acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * D + dd), acc); - _mm512_storeu_ps(optr + dd, acc); + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); + _mm256_storeu_ps(optr + dd, acc); } - for (; dd < D; dd++) + for (; dd < d; dd++) { float acc = optr[dd]; for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * D + dd]; + acc += pptr[j] * V[j * d + dd]; optr[dd] = acc; } } } -static inline void softmax_tile_avx512(float* P, const float* S, +template +static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(optr + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(optr + vi * 8, acc[vi]); + } +} + + +static inline void softmax_tile_avx(float* P, const float* S, float* m_vec, float* l_vec, float* scale_out, int m, int n) { for (int i = 0; i < 
m; i++) @@ -1138,157 +2976,130 @@ static inline void softmax_tile_avx512(float* P, const float* S, const float* sptr = S + i * n; float* pptr = P + i * n; - __m512 vmax = _mm512_set1_ps(m_vec[i]); + __m256 vmax = _mm256_set1_ps(m_vec[i]); int j = 0; - for (; j + 15 < n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); - if (j < n) - { - __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); - __m512 tail = _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, sptr + j); - vmax = _mm512_max_ps(vmax, tail); - } - float m_new = _mm512_comp_reduce_max_ps(vmax); + for (; j + 7 < n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(sptr + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; l_vec[i] *= scale_factor; - __m512 vm_new = _mm512_set1_ps(m_new); - __m512 vsum = _mm512_setzero_ps(); + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); j = 0; - for (; j + 15 < n; j += 16) + for (; j + 7 < n; j += 8) { - __m512 svec = _mm512_loadu_ps(sptr + j); - __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); - _mm512_storeu_ps(pptr + j, evec); - vsum = _mm512_add_ps(vsum, evec); + __m256 svec = _mm256_loadu_ps(sptr + j); + __m256 evec = exp256_ps(_mm256_sub_ps(svec, vm_new)); + _mm256_storeu_ps(pptr + j, evec); + vsum = _mm256_add_ps(vsum, evec); } - if (j < n) + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < n; j++) { - __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); - __m512 svec = _mm512_maskz_loadu_ps(mask, sptr + j); - __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); - _mm512_mask_storeu_ps(pptr + j, mask, evec); - vsum = _mm512_mask_add_ps(vsum, mask, vsum, evec); + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; } - float l_add = _mm512_comp_reduce_add_ps(vsum); l_vec[i] += l_add; m_vec[i] = m_new; } } -static inline void vec_scale_avx512(float* x, float s, int n) +static inline 
void vec_scale_avx(float* x, float s, int n) { - __m512 vscale = _mm512_set1_ps(s); + __m256 vscale = _mm256_set1_ps(s); int i = 0; - for (; i + 15 < n; i += 16) - _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); - if (i < n) - { - __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); - _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale)); - } + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale)); + for (; i < n; i++) + x[i] *= s; } -static inline void vec_zero_avx512(float* x, int n) +static inline void vec_zero_avx(float* x, int n) { - __m512 zero = _mm512_setzero_ps(); + __m256 zero = _mm256_setzero_ps(); int i = 0; - for (; i + 15 < n; i += 16) - _mm512_storeu_ps(x + i, zero); - if (i < n) - { - __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); - _mm512_mask_storeu_ps(x + i, mask, zero); - } + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, zero); + for (; i < n; i++) + x[i] = 0.f; } -static inline void decode_qk_dot_avx512(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) +static inline void decode_qk_dot_avx(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) { int j = 0; - for (; j + 3 < block_n; j += 4) + for (; j + 1 < block_n; j += 2) { const float* k0 = K + (n_start + j + 0) * d; const float* k1 = K + (n_start + j + 1) * d; - const float* k2 = K + (n_start + j + 2) * d; - const float* k3 = K + (n_start + j + 3) * d; - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); int k = 0; - for (; k + 15 < d; k += 16) + for (; k + 7 < d; k += 8) { - __m512 qv = _mm512_loadu_ps(q + k); - acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); - acc2 = 
_mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + __m256 qv = _mm256_loadu_ps(q + k); + acc0 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k1 + k), acc1); } - if (k < d) + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) { - __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); - __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); - acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + sum0 += q[k] * k0[k]; + sum1 += q[k] * k1[k]; } - s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; - s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; } for (; j < block_n; j++) { - __m512 acc = _mm512_setzero_ps(); + const float* kptr = K + (n_start + j) * d; + __m256 acc = _mm256_setzero_ps(); int k = 0; - for (; k + 15 < d; k += 16) - acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), _mm512_loadu_ps(K + (n_start + j) * d + k), acc); - if (k < d) - { - __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); - acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), _mm512_maskz_loadu_ps(mask_d, K + (n_start + j) * d + k), acc); - } - s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + for (; k + 7 < d; k += 8) + acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), _mm256_loadu_ps(kptr + k), acc); + float sum = _mm256_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * kptr[k]; + s[j] = sum * scale; } } -static inline void decode_pv_gemv_avx512(float* out, const float* s, const float* V, int n_start, 
int block_n, int out_d) +static inline void decode_pv_gemv_avx(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) { for (int j = 0; j < block_n; j++) { - __m512 pvec = _mm512_set1_ps(s[j]); + __m256 pvec = _mm256_set1_ps(s[j]); int k = 0; - for (; k + 15 < out_d; k += 16) - { - __m512 oval = _mm512_loadu_ps(out + k); - __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); - _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vval, oval)); - } - if (k < out_d) + for (; k + 7 < out_d; k += 8) { - __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); - __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); - __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); - _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec, vval, oval)); + __m256 oval = _mm256_loadu_ps(out + k); + __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); + _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec, vval, oval)); } + for (; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; } } -static void sdpa_decode_avx512(float* out, const float* q, +static void sdpa_decode_avx(float* out, const float* q, const float* K, const float* V, const float* mask, int n, int d, int out_d, float scale) { const int BLOCK_N = 128; - __attribute__((aligned(64))) float s[BLOCK_N]; + __attribute__((aligned(32))) float s[BLOCK_N]; - vec_zero_avx512(out, out_d); + vec_zero_avx(out, out_d); float m = -FLT_MAX; float l = 0.f; @@ -1297,143 +3108,92 @@ static void sdpa_decode_avx512(float* out, const float* q, { int block_n = std::min(BLOCK_N, n - n_start); - decode_qk_dot_avx512(s, q, K, n_start, block_n, d, scale); + decode_qk_dot_avx(s, q, K, n_start, block_n, d, scale); if (mask) { int j = 0; - for (; j + 15 < block_n; j += 16) - _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - 
_mm512_mask_storeu_ps(s + j, mask_n, - _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); - } + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; } - __m512 vmax = _mm512_set1_ps(-FLT_MAX); + __m256 vmax = _mm256_set1_ps(-FLT_MAX); int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); - } - float tile_m = _mm512_comp_reduce_max_ps(vmax); + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float tile_m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); float new_m = std::max(m, tile_m); if (m != new_m) { float scale_factor = expf(m - new_m); l *= scale_factor; - vec_scale_avx512(out, scale_factor, out_d); + vec_scale_avx(out, scale_factor, out_d); } - __m512 vm_new = _mm512_set1_ps(new_m); - __m512 vsum = _mm512_setzero_ps(); + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); j = 0; - for (; j + 15 < block_n; j += 16) + for (; j + 7 < block_n; j += 8) { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); - _mm512_storeu_ps(s + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); } - if (j < block_n) + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); - _mm512_mask_storeu_ps(s + j, mask_n, pvec); - vsum = _mm512_mask_add_ps(vsum, 
mask_n, vsum, pvec); + s[j] = expf(s[j] - new_m); + l_add += s[j]; } - l += _mm512_comp_reduce_add_ps(vsum); + l += l_add; - decode_pv_gemv_avx512(out, s, V, n_start, block_n, out_d); + decode_pv_gemv_avx(out, s, V, n_start, block_n, out_d); m = new_m; } float inv_l = 1.f / l; - vec_scale_avx512(out, inv_l, out_d); + vec_scale_avx(out, inv_l, out_d); } -#endif // __AVX512F__ +#endif // __AVX__ -#if __AVX__ +#if __SSE2__ -static void qk_gemm_avx(float* S, const float* Q, const float* K, +static void qk_gemm_sse2(float* S, const float* Q, const float* K, int m, int n, int d, float scale) { int i = 0; - for (; i + 6 <= m; i += 6) + for (; i + 4 <= m; i += 4) { int j = 0; - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * d; - const float* k1 = K + (j + 1) * d; - - __m256 acc[6][2]; - for (int mi = 0; mi < 6; mi++) - { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - } - - int k = 0; - for (; k + 7 < d; k += 8) - { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); - - for (int mi = 0; mi < 6; mi++) - { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 6; mi++) - { - float sum0 = _mm256_reduce_add_ps(acc[mi][0]); - float sum1 = _mm256_reduce_add_ps(acc[mi][1]); - - for (; k < d; k++) - { - float qv = Q[(i + mi) * d + k]; - sum0 += qv * k0[k]; - sum1 += qv * k1[k]; - } - - S[(i + mi) * n + j + 0] = sum0 * scale; - S[(i + mi) * n + j + 1] = sum1 * scale; - } - } - for (; j < n; j++) { const float* kptr = K + j * d; - __m256 acc[6]; - for (int mi = 0; mi < 6; mi++) - acc[mi] = _mm256_setzero_ps(); + __m128 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm_setzero_ps(); int k = 0; - for (; k + 7 < d; k += 8) + for (; k + 3 < d; k += 4) { - __m256 kvec = _mm256_loadu_ps(kptr + k); - for (int mi = 0; mi < 6; mi++) + __m128 kvec = 
_mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k); - acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 6; mi++) + for (int mi = 0; mi < 4; mi++) { - float sum = _mm256_reduce_add_ps(acc[mi]); + float sum = _mm_reduce_add_ps(acc[mi]); for (; k < d; k++) sum += Q[(i + mi) * d + k] * kptr[k]; S[(i + mi) * n + j] = sum * scale; @@ -1443,629 +3203,681 @@ static void qk_gemm_avx(float* S, const float* Q, const float* K, for (; i < m; i++) { - int j = 0; - for (; j + 1 < n; j += 2) + for (int j = 0; j < n; j++) { const float* qptr = Q + i * d; - const float* k0 = K + (j + 0) * d; - const float* k1 = K + (j + 1) * d; + const float* kptr = K + j * d; + float sum = 0.f; + int k = 0; + __m128 vacc = _mm_setzero_ps(); + for (; k + 3 < d; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + sum = _mm_reduce_add_ps(vacc); + for (; k < d; k++) + sum += qptr[k] * kptr[k]; + S[i * n + j] = sum * scale; + } + } +} - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - int k = 0; - for (; k + 7 < d; k += 8) - { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); - } +template +static inline void qk_gemm_specialized_sse2(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 4 <= m; i += 4) + { + for (int j = 0; j < n; j++) + { + const float* kptr = K + j * D; - float sum0 = _mm256_reduce_add_ps(acc0); - float sum1 = _mm256_reduce_add_ps(acc1); + const float* q0 = Q + (i + 0) * D; + const float* q1 = Q + (i + 1) * D; + const float* q2 = Q + (i + 2) * D; + const float* q3 = Q + (i + 3) * D; - for (; k < d; k++) + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = 
_mm_setzero_ps(); + __m128 acc2 = _mm_setzero_ps(); + __m128 acc3 = _mm_setzero_ps(); + + for (int k = 0; k < D; k += 4) { - float qv = qptr[k]; - sum0 += qv * k0[k]; - sum1 += qv * k1[k]; + __m128 kvec = _mm_loadu_ps(kptr + k); + + __m128 qvec = _mm_loadu_ps(q0 + k); + acc0 = _mm_comp_fmadd_ps(qvec, kvec, acc0); + + qvec = _mm_loadu_ps(q1 + k); + acc1 = _mm_comp_fmadd_ps(qvec, kvec, acc1); + + qvec = _mm_loadu_ps(q2 + k); + acc2 = _mm_comp_fmadd_ps(qvec, kvec, acc2); + + qvec = _mm_loadu_ps(q3 + k); + acc3 = _mm_comp_fmadd_ps(qvec, kvec, acc3); } - S[i * n + j + 0] = sum0 * scale; - S[i * n + j + 1] = sum1 * scale; + S[(i + 0) * n + j] = _mm_reduce_add_ps(acc0) * scale; + S[(i + 1) * n + j] = _mm_reduce_add_ps(acc1) * scale; + S[(i + 2) * n + j] = _mm_reduce_add_ps(acc2) * scale; + S[(i + 3) * n + j] = _mm_reduce_add_ps(acc3) * scale; } + } - for (; j < n; j++) + for (; i < m; i++) + { + for (int j = 0; j < n; j++) { - const float* qptr = Q + i * d; - const float* kptr = K + j * d; - float sum = 0.f; - int k = 0; - __m256 vacc = _mm256_setzero_ps(); - for (; k + 7 < d; k += 8) - vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); - sum = _mm256_reduce_add_ps(vacc); - for (; k < d; k++) - sum += qptr[k] * kptr[k]; - S[i * n + j] = sum * scale; + const float* qptr = Q + i * D; + const float* kptr = K + j * D; + __m128 vacc = _mm_setzero_ps(); + for (int k = 0; k < D; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } } } - -template -static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float* K, +template<> +void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; - for (; i + 6 <= m; i += 6) + for (; i + 4 <= m; i += 4) { int j = 0; - for (; j + 2 <= n; j += 2) + for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * D; - const float* k1 = K + (j + 
1) * D; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; - __m256 acc[6][2]; - for (int mi = 0; mi < 6; mi++) + __m128 acc[4][4]; + for (int mi = 0; mi < 4; mi++) { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); } - for (int k = 0; k < D; k += 8) + for (int k = 0; k < 1024; k += 4) { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); - for (int mi = 0; mi < 6; mi++) + for (int mi = 0; mi < 4; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); } } - for (int mi = 0; mi < 6; mi++) + for (int mi = 0; mi < 4; mi++) { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; } } - for (; j < n; j++) + for (; j + 2 <= n; j += 2) { - const float* kptr = K + j * D; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; - __m256 acc[6]; - for (int mi = 0; mi < 6; 
mi++) - acc[mi] = _mm256_setzero_ps(); + __m128 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + } - for (int k = 0; k < D; k += 8) + for (int k = 0; k < 1024; k += 4) { - __m256 kvec = _mm256_loadu_ps(kptr + k); - for (int mi = 0; mi < 6; mi++) + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + + for (int mi = 0; mi < 4; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); - acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); } } - for (int mi = 0; mi < 6; mi++) - S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + } } - } - for (; i < m; i++) - { - int j = 0; - for (; j + 1 < n; j += 2) + for (; j < n; j++) { - const float* qptr = Q + i * D; - const float* k0 = K + (j + 0) * D; - const float* k1 = K + (j + 1) * D; + const float* kptr = K + j * 1024; - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); + __m128 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm_setzero_ps(); - for (int k = 0; k < D; k += 8) + for (int k = 0; k < 1024; k += 4) { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } } - S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * D; - const 
float* kptr = K + j * D; - __m256 vacc = _mm256_setzero_ps(); - for (int k = 0; k < D; k += 8) - vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; } } -} - -template -static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) -{ - const int VEC_PER_UNROLL = D_UNROLL / 8; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) + for (; i + 2 <= m; i += 2) { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + int j = 0; + for (; j + 4 <= n; j += 4) { - op[mi] = O + (i + mi) * d; - pptr[mi] = P + (i + mi) * n; - } + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { - __m256 acc[M_BLOCK][VEC_PER_UNROLL]; - for (int mi = 0; mi < M_BLOCK; mi++) - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[mi][vi] = _mm256_loadu_ps(op[mi] + dd + vi * 8); + __m128 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); + } - for (int j = 0; j < n; j++) + for (int k = 0; k < 1024; k += 4) { - __m256 vvec[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - vvec[vi] = _mm256_loadu_ps(V + j * d + dd + vi * 8); + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); - for (int mi = 0; mi < M_BLOCK; mi++) + for (int mi = 0; mi < 2; mi++) { - __m256 pvec = _mm256_set1_ps(pptr[mi][j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[mi][vi] = _mm256_comp_fmadd_ps(pvec, vvec[vi], acc[mi][vi]); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 
1024 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); } } - for (int mi = 0; mi < M_BLOCK; mi++) - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm256_storeu_ps(op[mi] + dd + vi * 8, acc[mi][vi]); - } - - for (; dd + 7 < d; dd += 8) - { - __m256 acc[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm256_loadu_ps(op[mi] + dd); - - for (int j = 0; j < n; j++) + for (int mi = 0; mi < 2; mi++) { - __m256 vvec = _mm256_loadu_ps(V + j * d + dd); - for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), vvec, acc[mi]); + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; } - - for (int mi = 0; mi < M_BLOCK; mi++) - _mm256_storeu_ps(op[mi] + dd, acc[mi]); } - for (; dd < d; dd++) + for (; j + 2 <= n; j += 2) { - for (int mi = 0; mi < M_BLOCK; mi++) + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + + __m128 acc[2][2]; + for (int mi = 0; mi < 2; mi++) { - float acc = op[mi][dd]; - for (int j = 0; j < n; j++) - acc += pptr[mi][j] * V[j * d + dd]; - op[mi][dd] = acc; + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); } - } - } - for (; i < m; i++) - { - float* optr = O + i * d; - const float* pptr = P + i * n; + for (int k = 0; k < 1024; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); - int dd = 0; - for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) - { - __m256 acc[VEC_PER_UNROLL]; - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_loadu_ps(optr + dd + vi * 8); + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + 
mi) * 1024 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } - for (int j = 0; j < n; j++) + for (int mi = 0; mi < 2; mi++) { - __m256 pvec = _mm256_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * d + dd + vi * 8), acc[vi]); + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; } - - for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - _mm256_storeu_ps(optr + dd + vi * 8, acc[vi]); - } - - for (; dd + 7 < d; dd += 8) - { - __m256 acc = _mm256_loadu_ps(optr + dd); - for (int j = 0; j < n; j++) - acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); - _mm256_storeu_ps(optr + dd, acc); - } - - for (; dd < d; dd++) - { - float acc = optr[dd]; - for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * d + dd]; - optr[dd] = acc; } - } -} - -template -static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) -{ - const int VEC_PER_D = D / 8; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) - { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) + for (; j < n; j++) { - op[mi] = O + (i + mi) * D; - pptr[mi] = P + (i + mi) * n; - } + const float* kptr = K + j * 1024; - for (int mi = 0; mi < M_BLOCK; mi++) - { - __m256 acc[VEC_PER_D]; - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); + __m128 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm_setzero_ps(); - for (int j = 0; j < n; j++) + for (int k = 0; k < 1024; k += 4) { - __m256 pvec = _mm256_set1_ps(pptr[mi][j]); - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); - } - - for (int vi = 0; vi < VEC_PER_D; vi++) - _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); - } - } - - for (; i 
< m; i++) - { - float* optr = O + i * D; - const float* pptr = P + i * n; - - __m256 acc[VEC_PER_D]; - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_loadu_ps(optr + vi * 8); + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } - for (int j = 0; j < n; j++) - { - __m256 pvec = _mm256_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; } - - for (int vi = 0; vi < VEC_PER_D; vi++) - _mm256_storeu_ps(optr + vi * 8, acc[vi]); } -} - -static inline void softmax_tile_avx(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) -{ - for (int i = 0; i < m; i++) + for (; i < m; i++) { - const float* sptr = S + i * n; - float* pptr = P + i * n; - - __m256 vmax = _mm256_set1_ps(m_vec[i]); int j = 0; - for (; j + 7 < n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(sptr + j)); - float m_new = _mm256_reduce_max_ps(vmax); - for (; j < n; j++) - m_new = std::max(m_new, sptr[j]); + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; - float scale_factor = expf(m_vec[i] - m_new); - scale_out[i] = scale_factor; - l_vec[i] *= scale_factor; + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + __m128 acc2 = _mm_setzero_ps(); + __m128 acc3 = _mm_setzero_ps(); - __m256 vm_new = _mm256_set1_ps(m_new); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < n; j += 8) + for (int k = 0; k < 1024; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = 
_mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + acc2 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k2 + k), acc2); + acc3 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k3 + k), acc3); + } + + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm_reduce_add_ps(acc3) * scale; + } + + for (; j + 2 <= n; j += 2) { - __m256 svec = _mm256_loadu_ps(sptr + j); - __m256 evec = exp256_ps(_mm256_sub_ps(svec, vm_new)); - _mm256_storeu_ps(pptr + j, evec); - vsum = _mm256_add_ps(vsum, evec); + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + + for (int k = 0; k < 1024; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; } - float l_add = _mm256_reduce_add_ps(vsum); + for (; j < n; j++) { - pptr[j] = expf(sptr[j] - m_new); - l_add += pptr[j]; + const float* qptr = Q + i * 1024; + const float* kptr = K + j * 1024; + __m128 vacc = _mm_setzero_ps(); + for (int k = 0; k < 1024; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } - l_vec[i] += l_add; - m_vec[i] = m_new; } } -static inline void vec_scale_avx(float* x, float s, int n) +template<> +void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) { - __m256 vscale = _mm256_set1_ps(s); int i = 0; - for (; i + 7 < n; i += 8) - _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale)); - for (; i < n; i++) - x[i] *= s; -} + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 
4) + { + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + const float* k2 = K + (j + 2) * 2048; + const float* k3 = K + (j + 3) * 2048; -static inline void vec_zero_avx(float* x, int n) -{ - __m256 zero = _mm256_setzero_ps(); - int i = 0; - for (; i + 7 < n; i += 8) - _mm256_storeu_ps(x + i, zero); - for (; i < n; i++) - x[i] = 0.f; -} + __m128 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); + } -static inline void decode_qk_dot_avx(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) -{ - int j = 0; - for (; j + 1 < block_n; j += 2) - { - const float* k0 = K + (n_start + j + 0) * d; - const float* k1 = K + (n_start + j + 1) * d; + for (int k = 0; k < 2048; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } - int k = 0; - for (; k + 7 < d; k += 8) - { - __m256 qv = _mm256_loadu_ps(q + k); - acc0 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k1 + k), acc1); + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; + } } - float sum0 = _mm256_reduce_add_ps(acc0); - 
float sum1 = _mm256_reduce_add_ps(acc1); - - for (; k < d; k++) + for (; j + 2 <= n; j += 2) { - sum0 += q[k] * k0[k]; - sum1 += q[k] * k1[k]; - } + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; - s[j + 0] = sum0 * scale; - s[j + 1] = sum1 * scale; - } + __m128 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + } - for (; j < block_n; j++) - { - const float* kptr = K + (n_start + j) * d; - __m256 acc = _mm256_setzero_ps(); - int k = 0; - for (; k + 7 < d; k += 8) - acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), _mm256_loadu_ps(kptr + k), acc); - float sum = _mm256_reduce_add_ps(acc); - for (; k < d; k++) - sum += q[k] * kptr[k]; - s[j] = sum * scale; - } -} + for (int k = 0; k < 2048; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); -static inline void decode_pv_gemv_avx(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) -{ - for (int j = 0; j < block_n; j++) - { - __m256 pvec = _mm256_set1_ps(s[j]); - int k = 0; - for (; k + 7 < out_d; k += 8) - { - __m256 oval = _mm256_loadu_ps(out + k); - __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); - _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec, vval, oval)); + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + } } - for (; k < out_d; k++) - out[k] += s[j] * V[(n_start + j) * out_d + k]; - } -} -static void sdpa_decode_avx(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) -{ - const int BLOCK_N = 128; - __attribute__((aligned(32))) float 
s[BLOCK_N]; + for (; j < n; j++) + { + const float* kptr = K + j * 2048; - vec_zero_avx(out, out_d); + __m128 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm_setzero_ps(); - float m = -FLT_MAX; - float l = 0.f; + for (int k = 0; k < 2048; k += 4) + { + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 2048 + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } - for (int n_start = 0; n_start < n; n_start += BLOCK_N) + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) { - int block_n = std::min(BLOCK_N, n - n_start); + int j = 0; + for (; j + 4 <= n; j += 4) + { + const float* qptr = Q + i * 2048; + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + const float* k2 = K + (j + 2) * 2048; + const float* k3 = K + (j + 3) * 2048; - decode_qk_dot_avx(s, q, K, n_start, block_n, d, scale); + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + __m128 acc2 = _mm_setzero_ps(); + __m128 acc3 = _mm_setzero_ps(); - if (mask) - { - int j = 0; - for (; j + 7 < block_n; j += 8) - _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; - } + for (int k = 0; k < 2048; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + acc2 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k2 + k), acc2); + acc3 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k3 + k), acc3); + } - __m256 vmax = _mm256_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); - float tile_m = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] 
= _mm_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm_reduce_add_ps(acc3) * scale; + } - float new_m = std::max(m, tile_m); - if (m != new_m) + for (; j + 2 <= n; j += 2) { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_avx(out, scale_factor, out_d); + const float* qptr = Q + i * 2048; + const float* k0 = K + (j + 0) * 2048; + const float* k1 = K + (j + 1) * 2048; + + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); + + for (int k = 0; k < 2048; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + } + + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; } - __m256 vm_new = _mm256_set1_ps(new_m); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) + for (; j < n; j++) { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); - _mm256_storeu_ps(s + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); + const float* qptr = Q + i * 2048; + const float* kptr = K + j * 2048; + __m128 vacc = _mm_setzero_ps(); + for (int k = 0; k < 2048; k += 4) + vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); + S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } - float l_add = _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) + } +} + +template<> +void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + int i = 0; + for (; i + 2 <= m; i += 2) + { + int j = 0; + for (; j + 4 <= n; j += 4) { - s[j] = expf(s[j] - new_m); - l_add += s[j]; + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + const float* k2 = K + (j + 2) * 4096; + const float* k3 = K + (j + 3) * 4096; + + __m128 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); 
+ acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); + } + + for (int k = 0; k < 4096; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); + + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; + } } - l += l_add; - decode_pv_gemv_avx(out, s, V, n_start, block_n, out_d); + for (; j + 2 <= n; j += 2) + { + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; - m = new_m; - } + __m128 acc[2][2]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + } - float inv_l = 1.f / l; - vec_scale_avx(out, inv_l, out_d); -} + for (int k = 0; k < 4096; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); -#endif // __AVX__ + for (int mi = 0; mi < 2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 4096 + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } -#if __SSE2__ + for (int mi = 0; mi < 2; mi++) + { + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + } + } -static void qk_gemm_sse2(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) -{ - int i = 0; - for (; i 
+ 4 <= m; i += 4) - { - int j = 0; for (; j < n; j++) { - const float* kptr = K + j * d; + const float* kptr = K + j * 4096; - __m128 acc[4]; - for (int mi = 0; mi < 4; mi++) + __m128 acc[2]; + for (int mi = 0; mi < 2; mi++) acc[mi] = _mm_setzero_ps(); - int k = 0; - for (; k + 3 < d; k += 4) + for (int k = 0; k < 4096; k += 4) { __m128 kvec = _mm_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * d + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 4096 + k); acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 4; mi++) - { - float sum = _mm_reduce_add_ps(acc[mi]); - for (; k < d; k++) - sum += Q[(i + mi) * d + k] * kptr[k]; - S[(i + mi) * n + j] = sum * scale; - } + for (int mi = 0; mi < 2; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; } } for (; i < m; i++) { - for (int j = 0; j < n; j++) - { - const float* qptr = Q + i * d; - const float* kptr = K + j * d; - float sum = 0.f; - int k = 0; - __m128 vacc = _mm_setzero_ps(); - for (; k + 3 < d; k += 4) - vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); - sum = _mm_reduce_add_ps(vacc); - for (; k < d; k++) - sum += qptr[k] * kptr[k]; - S[i * n + j] = sum * scale; - } - } -} - - -template -static inline void qk_gemm_specialized_sse2(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 4 <= m; i += 4) - { - for (int j = 0; j < n; j++) + int j = 0; + for (; j + 4 <= n; j += 4) { - const float* kptr = K + j * D; - - const float* q0 = Q + (i + 0) * D; - const float* q1 = Q + (i + 1) * D; - const float* q2 = Q + (i + 2) * D; - const float* q3 = Q + (i + 3) * D; + const float* qptr = Q + i * 4096; + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; + const float* k2 = K + (j + 2) * 4096; + const float* k3 = K + (j + 3) * 4096; __m128 acc0 = _mm_setzero_ps(); __m128 acc1 = _mm_setzero_ps(); __m128 
acc2 = _mm_setzero_ps(); __m128 acc3 = _mm_setzero_ps(); - for (int k = 0; k < D; k += 4) + for (int k = 0; k < 4096; k += 4) { - __m128 kvec = _mm_loadu_ps(kptr + k); + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + acc2 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k2 + k), acc2); + acc3 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k3 + k), acc3); + } - __m128 qvec = _mm_loadu_ps(q0 + k); - acc0 = _mm_comp_fmadd_ps(qvec, kvec, acc0); + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; + S[i * n + j + 2] = _mm_reduce_add_ps(acc2) * scale; + S[i * n + j + 3] = _mm_reduce_add_ps(acc3) * scale; + } - qvec = _mm_loadu_ps(q1 + k); - acc1 = _mm_comp_fmadd_ps(qvec, kvec, acc1); + for (; j + 2 <= n; j += 2) + { + const float* qptr = Q + i * 4096; + const float* k0 = K + (j + 0) * 4096; + const float* k1 = K + (j + 1) * 4096; - qvec = _mm_loadu_ps(q2 + k); - acc2 = _mm_comp_fmadd_ps(qvec, kvec, acc2); + __m128 acc0 = _mm_setzero_ps(); + __m128 acc1 = _mm_setzero_ps(); - qvec = _mm_loadu_ps(q3 + k); - acc3 = _mm_comp_fmadd_ps(qvec, kvec, acc3); + for (int k = 0; k < 4096; k += 4) + { + __m128 qvec = _mm_loadu_ps(qptr + k); + acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); + acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); } - S[(i + 0) * n + j] = _mm_reduce_add_ps(acc0) * scale; - S[(i + 1) * n + j] = _mm_reduce_add_ps(acc1) * scale; - S[(i + 2) * n + j] = _mm_reduce_add_ps(acc2) * scale; - S[(i + 3) * n + j] = _mm_reduce_add_ps(acc3) * scale; + S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; + S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; } - } - for (; i < m; i++) - { - for (int j = 0; j < n; j++) + for (; j < n; j++) { - const float* qptr = Q + i * D; - const float* kptr = K + j * D; + const float* qptr = Q + i * 4096; + const float* kptr = K + j * 4096; __m128 vacc = 
_mm_setzero_ps(); - for (int k = 0; k < D; k += 4) + for (int k = 0; k < 4096; k += 4) vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } @@ -2511,6 +4323,45 @@ static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, #elif __SSE2__ qk_gemm_specialized_sse2<160>(S, Q, K, m, n, scale); return; +#endif + } + if (d == 1024) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 2048) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 4096) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<4096>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<4096>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<4096>(S, Q, K, m, n, scale); + return; #endif } @@ -2565,6 +4416,45 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, #elif __SSE2__ pv_gemm_sse2<2, 16>(O, P, V, m, n, d); return; +#endif + } + if (d == 1024) + { +#if __AVX512F__ + pv_gemm_avx512<2, 1024>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 32>(O, P, V, m, n, d); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + return; +#endif + } + if (d == 2048) + { +#if __AVX512F__ + pv_gemm_avx512<2, 2048>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 32>(O, P, V, m, n, d); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + return; +#endif + } + if (d == 4096) + { +#if __AVX512F__ + pv_gemm_avx512<2, 4096>(O, P, V, m, n); + return; +#elif __AVX__ + 
pv_gemm_avx<2, 32>(O, P, V, m, n, d); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + return; #endif } @@ -2968,10 +4858,12 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } - Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat o_accum(out_embed_dim, BLOCK_M, num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + const bool large_dim = embed_dim > 512; + Mat q_batch(embed_dim, large_dim ? BLOCK_M : BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - if (s_vec.empty() || o_accum.empty()) + if (s_vec.empty() || o_accum.empty() || q_batch.empty()) return -100; int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; @@ -2996,6 +4888,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); + Mat q_batch_thread = q_batch.channel(get_omp_thread_num()); // Pre-resolve mask pointers for all heads in this group const float* mask_data[num_heads_per_group]; @@ -3018,6 +4911,20 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat m_state_tile = m_state.channel(idx); Mat l_state_tile = l_state.channel(idx); + if (!large_dim) + { + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + float* q_dst = q_batch_thread.row(hq * block_m); + for (int i = 0; i < block_m; i++) + { + memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); + } + } + } + // Initialize softmax state and zero output accumulator for all Q heads in this group for (int hq = 0; hq < num_heads_per_group; hq++) { @@ 
-3029,8 +4936,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to l_vec[i] = 0.f; } - float* o_ptr = o_accum_thread.channel(hq).row(0); - vec_zero_dispatch(o_ptr, out_embed_dim * BLOCK_M); + float* o_ptr = o_accum_thread.row(hq * block_m); + vec_zero_dispatch(o_ptr, out_embed_dim * block_m); } // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group @@ -3039,81 +4946,162 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; - // Phase A + B fused per head: reuse compact s_vec/p_vec buffer - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query.channel(q); - Mat top_blob_head = top_blob.channel(q); - - float* s_ptr = s_vec_thread.row(0); + float* s_ptr = s_vec_thread.row(0); + if (!large_dim) + { qk_gemm_dispatch(s_ptr, - query_head.row(m_start), + q_batch_thread.row(0), key_head.row(n_start), - block_m, block_n, embed_dim, _scale); + block_m * num_heads_per_group, block_n, embed_dim, _scale); - // Apply attention mask - if (attn_mask && mask_data[hq]) + for (int hq = 0; hq < num_heads_per_group; hq++) { - for (int i = 0; i < block_m; i++) + float* s_head = s_ptr + hq * block_m * block_n; + + if (attn_mask && mask_data[hq]) { - const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; - float* sptr = s_ptr + i * block_n; - int j = 0; -#if __AVX512F__ - for (; j + 15 < block_n; j += 16) + for (int i = 0; i < block_m; i++) { - _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; + float* sptr = s_head + i * block_n; + int j = 0; +#if __AVX512F__ + for (; j + 15 < block_n; j += 16) + { + _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); 
+ } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); + _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); + j = block_n; + } +#elif __AVX__ + for (; j + 7 < block_n; j += 8) + { + _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); + } +#elif __SSE2__ + for (; j + 3 < block_n; j += 4) + { + _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); + } +#endif + for (; j < block_n; j++) + { + sptr[j] += mptr[j]; + } } - if (j < block_n) + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_old[i] = m_vec[i]; + } + softmax_tile_dispatch(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n); + + for (int i = 0; i < block_m; i++) + { + if (m_old[i] != m_vec[i]) { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); - __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); - _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); - j = block_n; + vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } -#elif __AVX__ - for (; j + 7 < block_n; j += 8) + } + } + + pv_gemm_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } + else + { + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + + float* q_dst = q_batch_thread.row(0); + for (int i = 0; i < block_m; i++) + { + memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); + } + + float* s_head = s_ptr; + + qk_gemm_dispatch(s_head, + q_dst, + 
key_head.row(n_start), + block_m, block_n, embed_dim, _scale); + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) { - _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); - } + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; + float* sptr = s_head + i * block_n; + int j = 0; +#if __AVX512F__ + for (; j + 15 < block_n; j += 16) + { + _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); + } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); + _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); + j = block_n; + } +#elif __AVX__ + for (; j + 7 < block_n; j += 8) + { + _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); + } #elif __SSE2__ - for (; j + 3 < block_n; j += 4) - { - _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); - } + for (; j + 3 < block_n; j += 4) + { + _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); + } #endif - for (; j < block_n; j++) - { - sptr[j] += mptr[j]; + for (; j < block_n; j++) + { + sptr[j] += mptr[j]; + } } } - } - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.channel(hq).row(0); + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); - float m_old[BLOCK_M]; - float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile_dispatch(s_ptr, s_ptr, m_vec, l_vec, scale_factors, block_m, block_n); + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_old[i] = m_vec[i]; + } + softmax_tile_dispatch(s_head, s_head, m_vec, 
l_vec, scale_factors, block_m, block_n); - // Rescale O accumulator when max increases - for (int i = 0; i < block_m; i++) - { - if (m_old[i] != m_vec[i]) + for (int i = 0; i < block_m; i++) { - vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + if (m_old[i] != m_vec[i]) + { + vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } } - } - pv_gemm_dispatch(o_ptr, s_ptr, value_head.row(n_start), block_m, block_n, out_embed_dim); + pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start), + block_m, block_n, out_embed_dim); + } } } @@ -3123,7 +5111,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int q = g * num_heads_per_group + hq; Mat top_blob_head = top_blob.channel(q); float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.channel(hq).row(0); + float* o_ptr = o_accum_thread.row(hq * block_m); for (int i = 0; i < block_m; i++) { From 5de27cb9d612665ccd5cc82ff6d6468e879462d0 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 27 Apr 2026 01:39:02 +0800 Subject: [PATCH 15/53] import dim 512 spec --- src/layer/x86/sdpa_x86.cpp | 37 +++++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 4adc6ab940ba..b7854328dc31 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -4385,10 +4385,10 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<2, 128>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + pv_gemm_avx<2, 128>(O, P, V, m, n); return; #elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + pv_gemm_sse2<2, 128>(O, P, V, m, n); return; #endif } @@ -4398,10 +4398,10 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<4, 64>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + 
pv_gemm_avx<4, 64>(O, P, V, m, n); return; #elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + pv_gemm_sse2<4, 64>(O, P, V, m, n); return; #endif } @@ -4411,10 +4411,23 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<2, 256>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + pv_gemm_avx<2, 256>(O, P, V, m, n); return; #elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + pv_gemm_sse2<2, 256>(O, P, V, m, n); + return; +#endif + } + if (d == 512) + { +#if __AVX512F__ + pv_gemm_avx512<2, 512>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 512>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 512>(O, P, V, m, n); return; #endif } @@ -4424,10 +4437,10 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<2, 1024>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + pv_gemm_avx<2, 1024>(O, P, V, m, n); return; #elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + pv_gemm_sse2<2, 1024>(O, P, V, m, n); return; #endif } @@ -4437,10 +4450,10 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<2, 2048>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + pv_gemm_avx<2, 2048>(O, P, V, m, n); return; #elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + pv_gemm_sse2<2, 2048>(O, P, V, m, n); return; #endif } @@ -4450,10 +4463,10 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, pv_gemm_avx512<2, 4096>(O, P, V, m, n); return; #elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); + pv_gemm_avx<2, 4096>(O, P, V, m, n); return; #elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); + pv_gemm_sse2<2, 4096>(O, P, V, m, n); return; #endif } From 61b375ad467eff0e6ce68061fecc2714f8c688fa Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 27 Apr 2026 23:10:25 +0800 Subject: [PATCH 16/53] 
slim kernel --- src/layer/x86/sdpa_x86.cpp | 2559 +++++++++++++----------------------- 1 file changed, 881 insertions(+), 1678 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index b7854328dc31..593d8193431d 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -99,6 +99,197 @@ static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, } } +static inline void vec_scale(float* x, float s, int n) +{ +#if __AVX512F__ + __m512 vscale = _mm512_set1_ps(s); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale)); + } +#elif __AVX__ + __m256 vscale = _mm256_set1_ps(s); + int i = 0; + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale)); + for (; i < n; i++) + x[i] *= s; +#elif __SSE2__ + __m128 vscale = _mm_set1_ps(s); + int i = 0; + for (; i + 3 < n; i += 4) + _mm_storeu_ps(x + i, _mm_mul_ps(_mm_loadu_ps(x + i), vscale)); + for (; i < n; i++) + x[i] *= s; +#else + for (int i = 0; i < n; i++) + x[i] *= s; +#endif +} + +static inline void vec_zero(float* x, int n) +{ +#if __AVX512F__ + __m512 zero = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < n; i += 16) + _mm512_storeu_ps(x + i, zero); + if (i < n) + { + __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); + _mm512_mask_storeu_ps(x + i, mask, zero); + } +#elif __AVX__ + __m256 zero = _mm256_setzero_ps(); + int i = 0; + for (; i + 7 < n; i += 8) + _mm256_storeu_ps(x + i, zero); + for (; i < n; i++) + x[i] = 0.f; +#elif __SSE2__ + __m128 zero = _mm_setzero_ps(); + int i = 0; + for (; i + 3 < n; i += 4) + _mm_storeu_ps(x + i, zero); + for (; i < n; i++) + x[i] = 0.f; +#else + for (int i = 0; i < n; i++) + x[i] = 0.f; +#endif +} + +static inline void softmax_tile(float* P, const 
float* S, + float* m_vec, float* l_vec, float* scale_out, int m, int n) +{ + for (int i = 0; i < m; i++) + { + const float* sptr = S + i * n; + float* pptr = P + i * n; + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); + if (j < n) + { + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 tail = _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, sptr + j); + vmax = _mm512_max_ps(vmax, tail); + } + float m_new = _mm512_comp_reduce_max_ps(vmax); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + l_vec[i] *= scale_factor; + + __m512 vm_new = _mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < n; j += 16) + { + __m512 svec = _mm512_loadu_ps(sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_storeu_ps(pptr + j, evec); + vsum = _mm512_add_ps(vsum, evec); + } + if (j < n) + { + __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); + __m512 svec = _mm512_maskz_loadu_ps(mask, sptr + j); + __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); + _mm512_mask_storeu_ps(pptr + j, mask, evec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, evec); + } + float l_add = _mm512_comp_reduce_add_ps(vsum); + l_vec[i] += l_add; + m_vec[i] = m_new; +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(m_vec[i]); + int j = 0; + for (; j + 7 < n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(sptr + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + l_vec[i] *= scale_factor; + + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < n; j += 8) + { + __m256 svec = _mm256_loadu_ps(sptr + j); + __m256 evec = exp256_ps(_mm256_sub_ps(svec, vm_new)); + _mm256_storeu_ps(pptr + j, evec); + vsum = _mm256_add_ps(vsum, 
evec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(m_vec[i]); + int j = 0; + for (; j + 3 < n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(sptr + j)); + float m_new = _mm_reduce_max_ps(vmax); + for (; j < n; j++) + m_new = std::max(m_new, sptr[j]); + + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + l_vec[i] *= scale_factor; + + __m128 vm_new = _mm_set1_ps(m_new); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < n; j += 4) + { + __m128 svec = _mm_loadu_ps(sptr + j); + __m128 evec = exp_ps(_mm_sub_ps(svec, vm_new)); + _mm_storeu_ps(pptr + j, evec); + vsum = _mm_add_ps(vsum, evec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; +#else + float m_new = m_vec[i]; + for (int j = 0; j < n; j++) + m_new = std::max(m_new, sptr[j]); + float scale_factor = expf(m_vec[i] - m_new); + scale_out[i] = scale_factor; + l_vec[i] *= scale_factor; + float l_add = 0.f; + for (int j = 0; j < n; j++) + { + pptr[j] = expf(sptr[j] - m_new); + l_add += pptr[j]; + } + l_vec[i] += l_add; + m_vec[i] = m_new; +#endif + } +} + static inline void pv_gemm_scalar(float* O, const float* P, const float* V, int m, int n, int d) { @@ -116,24 +307,191 @@ static inline void pv_gemm_scalar(float* O, const float* P, const float* V, } } -static inline void vec_scale_scalar(float* x, float s, int n) +static inline void decode_pv_gemv(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) { - for (int i = 0; i < n; i++) x[i] *= s; + for (int j = 0; j < block_n; j++) + { +#if __AVX512F__ + __m512 pvec = _mm512_set1_ps(s[j]); + int k = 0; + for (; k + 15 < out_d; k += 16) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 vval = _mm512_loadu_ps(V + 
(n_start + j) * out_d + k); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vval, oval)); + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); + _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec, vval, oval)); + } +#elif __AVX__ + __m256 pvec = _mm256_set1_ps(s[j]); + int k = 0; + for (; k + 7 < out_d; k += 8) + { + __m256 oval = _mm256_loadu_ps(out + k); + __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); + _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; +#elif __SSE2__ + __m128 pvec = _mm_set1_ps(s[j]); + int k = 0; + for (; k + 3 < out_d; k += 4) + { + __m128 oval = _mm_loadu_ps(out + k); + __m128 vval = _mm_loadu_ps(V + (n_start + j) * out_d + k); + _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; +#else + for (int k = 0; k < out_d; k++) + out[k] += s[j] * V[(n_start + j) * out_d + k]; +#endif + } } -static inline void vec_zero_scalar(float* x, int n) +static inline void decode_qk_dot(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) { - for (int i = 0; i < n; i++) x[i] = 0.f; +#if __AVX512F__ + int j = 0; + for (; j + 3 < block_n; j += 4) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + const float* k2 = K + (n_start + j + 2) * d; + const float* k3 = K + (n_start + j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, 
_mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < block_n; j++) + { + __m512 acc = _mm512_setzero_ps(); + int k = 0; + for (; k + 15 < d; k += 16) + acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), _mm512_loadu_ps(K + (n_start + j) * d + k), acc); + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), _mm512_maskz_loadu_ps(mask_d, K + (n_start + j) * d + k), acc); + } + s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + } +#elif __AVX__ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qv = _mm256_loadu_ps(q + k); + acc0 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k0 + k), acc0); + acc1 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k1 + k), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < d; k++) + { + sum0 += q[k] * k0[k]; + sum1 += q[k] * k1[k]; + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + + for (; j < block_n; j++) + { + const 
float* kptr = K + (n_start + j) * d; + __m256 acc = _mm256_setzero_ps(); + int k = 0; + for (; k + 7 < d; k += 8) + acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), _mm256_loadu_ps(kptr + k), acc); + float sum = _mm256_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * kptr[k]; + s[j] = sum * scale; + } +#elif __SSE2__ + for (int j = 0; j < block_n; j++) + { + __m128 acc = _mm_setzero_ps(); + int k = 0; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(q + k), _mm_loadu_ps(K + (n_start + j) * d + k), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * K[(n_start + j) * d + k]; + s[j] = sum * scale; + } +#else + for (int j = 0; j < block_n; j++) + { + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += q[k] * K[(n_start + j) * d + k]; + s[j] = sum * scale; + } +#endif } -static void sdpa_decode_scalar(float* out, const float* q, +static void sdpa_decode(float* out, const float* q, const float* K, const float* V, const float* mask, int n, int d, int out_d, float scale) { const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else float s[BLOCK_N]; +#endif - for (int k = 0; k < out_d; k++) out[k] = 0.f; + vec_zero(out, out_d); float m = -FLT_MAX; float l = 0.f; @@ -142,33 +500,132 @@ static void sdpa_decode_scalar(float* out, const float* q, { int block_n = std::min(BLOCK_N, n - n_start); + decode_qk_dot(s, q, K, n_start, block_n, d, scale); + + if (mask) + { +#if __AVX512F__ + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + 
j))); + } +#elif __AVX__ + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; +#elif __SSE2__ + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; +#else + for (int j = 0; j < block_n; j++) + s[j] += mask[n_start + j]; +#endif + } + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + } + float tile_m = _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float tile_m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float tile_m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#else + float tile_m = -FLT_MAX; for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#endif + + float new_m = std::max(m, tile_m); + if (m != new_m) { - float sum = 0.f; - for (int k = 0; k < d; k++) - sum += q[k] * K[(n_start + j) * d + k]; - s[j] = sum * scale; + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(out, scale_factor, out_d); } - if (mask) +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = 
exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + l += _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) { - for (int j = 0; j < block_n; j++) - s[j] += mask[n_start + j]; + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); } - - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); - - float new_m = std::max(m, tile_m); - if (m != new_m) + float l_add = _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) { - float scale_factor = expf(m - new_m); - l *= scale_factor; - for (int k = 0; k < out_d; k++) - out[k] *= scale_factor; + s[j] = expf(s[j] - new_m); + l_add += s[j]; } - + l += l_add; +#else float l_add = 0.f; for (int j = 0; j < block_n; j++) { @@ -176,42 +633,15 @@ static void sdpa_decode_scalar(float* out, const float* q, l_add += s[j]; } l += l_add; +#endif - for (int j = 0; j < block_n; j++) - { - for (int k = 0; k < out_d; k++) - out[k] += s[j] * V[(n_start + j) * out_d + k]; - } + decode_pv_gemv(out, s, V, n_start, block_n, out_d); m = new_m; } float inv_l 
= 1.f / l; - for (int k = 0; k < out_d; k++) - out[k] *= inv_l; -} - -static inline void softmax_tile_scalar(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) -{ - for (int i = 0; i < m; i++) - { - const float* sptr = S + i * n; - float* pptr = P + i * n; - float m_new = m_vec[i]; - for (int j = 0; j < n; j++) m_new = std::max(m_new, sptr[j]); - float scale_factor = expf(m_vec[i] - m_new); - scale_out[i] = scale_factor; - l_vec[i] *= scale_factor; - float l_add = 0.f; - for (int j = 0; j < n; j++) - { - pptr[j] = expf(sptr[j] - m_new); - l_add += pptr[j]; - } - l_vec[i] += l_add; - m_vec[i] = m_new; - } + vec_scale(out, inv_l, out_d); } #if __AVX512F__ @@ -580,221 +1010,15 @@ static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const fl const float* qptr = Q + i * D; const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; - const float* k2 = K + (j + 2) * D; - const float* k3 = K + (j + 3) * D; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); - - for (int k = 0; k < D; k += 16) - { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * D; - const float* kptr = K + j * D; - __m512 vacc = _mm512_setzero_ps(); - for (int k = 0; k < D; k += 16) - vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - S[i * n + j] = 
_mm512_comp_reduce_add_ps(vacc) * scale; - } - } -} - -// Explicit specialization for D=128: 6x4 kernel to improve K-tile reuse -template<> -void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 6 <= m; i += 6) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - const float* k2 = K + (j + 2) * 128; - const float* k3 = K + (j + 3) * 128; - - __m512 acc[6][4]; - for (int mi = 0; mi < 6; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - acc[mi][2] = _mm512_setzero_ps(); - acc[mi][3] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 128; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - __m512 kv2 = _mm512_loadu_ps(k2 + k); - __m512 kv3 = _mm512_loadu_ps(k3 + k); - - for (int mi = 0; mi < 6; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); - } - } - - for (int mi = 0; mi < 6; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; - } - } - - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - - __m512 acc[6][2]; - for (int mi = 0; mi < 6; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 128; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - - for (int mi = 0; mi < 6; mi++) - { - __m512 
qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 6; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - const float* kptr = K + j * 128; - - __m512 acc[6]; - for (int mi = 0; mi < 6; mi++) - acc[mi] = _mm512_setzero_ps(); - - for (int k = 0; k < 128; k += 16) - { - __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 6; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); - } - } - - for (int mi = 0; mi < 6; mi++) - S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i + 4 <= m; i += 4) - { - int j = 0; - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - - __m512 acc[4][2]; - for (int mi = 0; mi < 4; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 128; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - - for (int mi = 0; mi < 4; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 4; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - __m512 acc[4]; - for (int mi = 0; mi < 4; mi++) - acc[mi] = _mm512_setzero_ps(); - - const float* kptr = K + j * 128; - for (int k = 0; k < 128; k += 16) - { - __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 
128 + k); - acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); - } - } - - for (int mi = 0; mi < 4; mi++) - S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 128; - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - const float* k2 = K + (j + 2) * 128; - const float* k3 = K + (j + 3) * 128; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); __m512 acc2 = _mm512_setzero_ps(); __m512 acc3 = _mm512_setzero_ps(); - for (int k = 0; k < 128; k += 16) + for (int k = 0; k < D; k += 16) { __m512 qvec = _mm512_loadu_ps(qptr + k); acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); @@ -809,55 +1033,36 @@ void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; } - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 128; - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - - for (int k = 0; k < 128; k += 16) - { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - } - - S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - } - for (; j < n; j++) { - const float* qptr = Q + i * 128; - const float* kptr = K + j * 128; + const float* qptr = Q + i * D; + const float* kptr = K + j * D; __m512 vacc = _mm512_setzero_ps(); - for (int k = 0; k < 128; k += 16) + for (int k = 0; k < D; k += 16) vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } } } +// Explicit specialization 
for D=128: 6x4 kernel to improve K-tile reuse template<> -void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, +void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; - for (; i + 4 <= m; i += 4) + for (; i + 6 <= m; i += 6) { int j = 0; for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + const float* k2 = K + (j + 2) * 128; + const float* k3 = K + (j + 3) * 128; - __m512 acc[4][4]; - for (int mi = 0; mi < 4; mi++) + __m512 acc[6][4]; + for (int mi = 0; mi < 6; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); @@ -865,16 +1070,16 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, acc[mi][3] = _mm512_setzero_ps(); } - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); __m512 kv2 = _mm512_loadu_ps(k2 + k); __m512 kv3 = _mm512_loadu_ps(k3 + k); - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 6; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); @@ -882,7 +1087,7 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, } } - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 6; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -893,30 +1098,30 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, for (; j + 2 <= n; 
j += 2) { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; - __m512 acc[4][2]; - for (int mi = 0; mi < 4; mi++) + __m512 acc[6][2]; + for (int mi = 0; mi < 6; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); } - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 6; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } } - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 6; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -925,98 +1130,56 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 1024; + const float* kptr = K + j * 128; - __m512 acc[4]; - for (int mi = 0; mi < 4; mi++) + __m512 acc[6]; + for (int mi = 0; mi < 6; mi++) acc[mi] = _mm512_setzero_ps(); - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) { __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 6; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < 6; mi++) S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; } } - for (; i + 2 <= m; i += 2) + for (; i + 4 <= m; i += 4) { int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 
1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m512 acc[2][4]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - acc[mi][2] = _mm512_setzero_ps(); - acc[mi][3] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - __m512 kv2 = _mm512_loadu_ps(k2 + k); - __m512 kv3 = _mm512_loadu_ps(k3 + k); - - for (int mi = 0; mi < 2; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); - } - } - - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; - } - } - for (; j + 2 <= n; j += 2) { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; - __m512 acc[2][2]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); } - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < 4; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } } - for (int mi = 0; 
mi < 2; mi++) + for (int mi = 0; mi < 4; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -1025,23 +1188,22 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 1024; - - __m512 acc[2]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) acc[mi] = _mm512_setzero_ps(); - for (int k = 0; k < 1024; k += 16) + const float* kptr = K + j * 128; + for (int k = 0; k < 128; k += 16) { __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < 4; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 1024 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < 4; mi++) S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; } } @@ -1051,18 +1213,18 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; + const float* qptr = Q + i * 128; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; + const float* k2 = K + (j + 2) * 128; + const float* k3 = K + (j + 3) * 128; __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); __m512 acc2 = _mm512_setzero_ps(); __m512 acc3 = _mm512_setzero_ps(); - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) { __m512 qvec = _mm512_loadu_ps(qptr + k); acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); @@ -1079,14 +1241,14 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const 
float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; + const float* qptr = Q + i * 128; + const float* k0 = K + (j + 0) * 128; + const float* k1 = K + (j + 1) * 128; __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) { __m512 qvec = _mm512_loadu_ps(qptr + k); acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); @@ -1099,33 +1261,36 @@ void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* qptr = Q + i * 1024; - const float* kptr = K + j * 1024; + const float* qptr = Q + i * 128; + const float* kptr = K + j * 128; __m512 vacc = _mm512_setzero_ps(); - for (int k = 0; k < 1024; k += 16) + for (int k = 0; k < 128; k += 16) vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } } } -template<> -void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, + +template +static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; - for (; i + 2 <= m; i += 2) + if (M_BLOCK2 > 0) + { + for (; i + M_BLOCK2 <= m; i += M_BLOCK2) { int j = 0; for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - const float* k2 = K + (j + 2) * 2048; - const float* k3 = K + (j + 3) * 2048; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; - __m512 acc[2][4]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[M_BLOCK2][4]; + for (int mi = 0; mi < M_BLOCK2; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); @@ -1133,16 +1298,16 @@ void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, acc[mi][3] = 
_mm512_setzero_ps(); } - for (int k = 0; k < 2048; k += 16) + for (int k = 0; k < D; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); __m512 kv2 = _mm512_loadu_ps(k2 + k); __m512 kv3 = _mm512_loadu_ps(k3 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 2048 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); @@ -1150,7 +1315,7 @@ void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -1161,30 +1326,30 @@ void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; - __m512 acc[2][2]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[M_BLOCK2][2]; + for (int mi = 0; mi < M_BLOCK2; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); } - for (int k = 0; k < 2048; k += 16) + for (int k = 0; k < D; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 2048 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; 
S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -1193,107 +1358,40 @@ void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 2048; + const float* kptr = K + j * D; - __m512 acc[2]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[M_BLOCK2]; + for (int mi = 0; mi < M_BLOCK2; mi++) acc[mi] = _mm512_setzero_ps(); - for (int k = 0; k < 2048; k += 16) + for (int k = 0; k < D; k += 16) { __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 2048 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; } } - - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 2048; - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - const float* k2 = K + (j + 2) * 2048; - const float* k3 = K + (j + 3) * 2048; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); - - for (int k = 0; k < 2048; k += 16) - { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; - } - - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 2048; - 
const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - - for (int k = 0; k < 2048; k += 16) - { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - } - - S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * 2048; - const float* kptr = K + j * 2048; - __m512 vacc = _mm512_setzero_ps(); - for (int k = 0; k < 2048; k += 16) - vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; - } } -} -template<> -void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 2 <= m; i += 2) + for (; i + M_BLOCK1 <= m; i += M_BLOCK1) { int j = 0; for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; - const float* k2 = K + (j + 2) * 4096; - const float* k3 = K + (j + 3) * 4096; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; - __m512 acc[2][4]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[M_BLOCK1][4]; + for (int mi = 0; mi < M_BLOCK1; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); @@ -1301,16 +1399,16 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, acc[mi][3] = _mm512_setzero_ps(); } - for (int k = 0; k < 4096; k += 16) + for (int k = 0; k < D; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); __m512 kv2 = _mm512_loadu_ps(k2 + k); __m512 kv3 = _mm512_loadu_ps(k3 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < 
M_BLOCK1; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 4096 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); @@ -1318,7 +1416,7 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -1329,30 +1427,30 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; - __m512 acc[2][2]; - for (int mi = 0; mi < 2; mi++) + __m512 acc[M_BLOCK1][2]; + for (int mi = 0; mi < M_BLOCK1; mi++) { acc[mi][0] = _mm512_setzero_ps(); acc[mi][1] = _mm512_setzero_ps(); } - for (int k = 0; k < 4096; k += 16) + for (int k = 0; k < D; k += 16) { __m512 kv0 = _mm512_loadu_ps(k0 + k); __m512 kv1 = _mm512_loadu_ps(k1 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 4096 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; @@ -1361,23 +1459,23 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 4096; + const float* kptr = K + j * D; - __m512 acc[2]; - for (int mi = 0; mi < 
2; mi++) + __m512 acc[M_BLOCK1]; + for (int mi = 0; mi < M_BLOCK1; mi++) acc[mi] = _mm512_setzero_ps(); - for (int k = 0; k < 4096; k += 16) + for (int k = 0; k < D; k += 16) { __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 4096 + k); + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; } } @@ -1387,18 +1485,18 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* qptr = Q + i * 4096; - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; - const float* k2 = K + (j + 2) * 4096; - const float* k3 = K + (j + 3) * 4096; + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); __m512 acc2 = _mm512_setzero_ps(); __m512 acc3 = _mm512_setzero_ps(); - for (int k = 0; k < 4096; k += 16) + for (int k = 0; k < D; k += 16) { __m512 qvec = _mm512_loadu_ps(qptr + k); acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); @@ -1415,14 +1513,14 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* qptr = Q + i * 4096; - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); - for (int k = 0; k < 4096; k += 16) + for (int k = 0; k < D; k += 16) { __m512 qvec = _mm512_loadu_ps(qptr + k); acc0 = 
_mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); @@ -1435,10 +1533,10 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* qptr = Q + i * 4096; - const float* kptr = K + j * 4096; + const float* qptr = Q + i * D; + const float* kptr = K + j * D; __m512 vacc = _mm512_setzero_ps(); - for (int k = 0; k < 4096; k += 16) + for (int k = 0; k < D; k += 16) vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } @@ -1446,6 +1544,30 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, } +template<> +void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + + qk_gemm_specialized_tiled_avx512<2048, 2, 0>(S, Q, K, m, n, scale); +} +template<> +void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<1024, 2, 4>(S, Q, K, m, n, scale); +} + + + +template<> +void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<4096, 2, 0>(S, Q, K, m, n, scale); +} + + template static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) { @@ -1669,299 +1791,71 @@ static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int { float acc = op[mi][dd]; for (int j = 0; j < n; j++) - acc += pptr[mi][j] * V[j * D + dd]; - op[mi][dd] = acc; - } - } - } - - for (; i < m; i++) - { - float* optr = O + i * D; - const float* pptr = P + i * n; - - int dd = 0; - for (; dd + 127 < D; dd += 128) - { - __m512 acc0 = _mm512_loadu_ps(optr + dd + 0 * 16); - __m512 acc1 = _mm512_loadu_ps(optr + dd + 1 * 16); - __m512 acc2 = _mm512_loadu_ps(optr + dd + 2 * 16); - __m512 acc3 = _mm512_loadu_ps(optr + dd + 3 * 16); - __m512 acc4 = 
_mm512_loadu_ps(optr + dd + 4 * 16); - __m512 acc5 = _mm512_loadu_ps(optr + dd + 5 * 16); - __m512 acc6 = _mm512_loadu_ps(optr + dd + 6 * 16); - __m512 acc7 = _mm512_loadu_ps(optr + dd + 7 * 16); - - for (int j = 0; j < n; j++) - { - __m512 pvec = _mm512_set1_ps(pptr[j]); - acc0 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 0 * 16), acc0); - acc1 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 1 * 16), acc1); - acc2 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 2 * 16), acc2); - acc3 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 3 * 16), acc3); - acc4 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 4 * 16), acc4); - acc5 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 5 * 16), acc5); - acc6 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 6 * 16), acc6); - acc7 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 7 * 16), acc7); - } - - _mm512_storeu_ps(optr + dd + 0 * 16, acc0); - _mm512_storeu_ps(optr + dd + 1 * 16, acc1); - _mm512_storeu_ps(optr + dd + 2 * 16, acc2); - _mm512_storeu_ps(optr + dd + 3 * 16, acc3); - _mm512_storeu_ps(optr + dd + 4 * 16, acc4); - _mm512_storeu_ps(optr + dd + 5 * 16, acc5); - _mm512_storeu_ps(optr + dd + 6 * 16, acc6); - _mm512_storeu_ps(optr + dd + 7 * 16, acc7); - } - - for (; dd + 15 < D; dd += 16) - { - __m512 acc = _mm512_loadu_ps(optr + dd); - for (int j = 0; j < n; j++) - acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * D + dd), acc); - _mm512_storeu_ps(optr + dd, acc); - } - - for (; dd < D; dd++) - { - float acc = optr[dd]; - for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * D + dd]; - optr[dd] = acc; - } - } -} - - -static inline void softmax_tile_avx512(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) -{ - for (int i = 0; i < m; i++) - { - const float* sptr = S + i * n; - float* pptr = P + i * n; - - __m512 vmax = _mm512_set1_ps(m_vec[i]); - int j = 0; - for (; j + 15 < n; j += 16) - vmax 
= _mm512_max_ps(vmax, _mm512_loadu_ps(sptr + j)); - if (j < n) - { - __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); - __m512 tail = _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, sptr + j); - vmax = _mm512_max_ps(vmax, tail); - } - float m_new = _mm512_comp_reduce_max_ps(vmax); - - float scale_factor = expf(m_vec[i] - m_new); - scale_out[i] = scale_factor; - l_vec[i] *= scale_factor; - - __m512 vm_new = _mm512_set1_ps(m_new); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < n; j += 16) - { - __m512 svec = _mm512_loadu_ps(sptr + j); - __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); - _mm512_storeu_ps(pptr + j, evec); - vsum = _mm512_add_ps(vsum, evec); - } - if (j < n) - { - __mmask16 mask = (__mmask16)((1u << (n - j)) - 1); - __m512 svec = _mm512_maskz_loadu_ps(mask, sptr + j); - __m512 evec = exp512_ps(_mm512_sub_ps(svec, vm_new)); - _mm512_mask_storeu_ps(pptr + j, mask, evec); - vsum = _mm512_mask_add_ps(vsum, mask, vsum, evec); - } - float l_add = _mm512_comp_reduce_add_ps(vsum); - l_vec[i] += l_add; - m_vec[i] = m_new; - } -} - -static inline void vec_scale_avx512(float* x, float s, int n) -{ - __m512 vscale = _mm512_set1_ps(s); - int i = 0; - for (; i + 15 < n; i += 16) - _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); - if (i < n) - { - __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); - _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale)); - } -} - -static inline void vec_zero_avx512(float* x, int n) -{ - __m512 zero = _mm512_setzero_ps(); - int i = 0; - for (; i + 15 < n; i += 16) - _mm512_storeu_ps(x + i, zero); - if (i < n) - { - __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); - _mm512_mask_storeu_ps(x + i, mask, zero); - } -} - -static inline void decode_qk_dot_avx512(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) -{ - int j = 0; - for (; j + 3 < block_n; j += 4) - { - const float* k0 = K + (n_start + j + 0) * d; - 
const float* k1 = K + (n_start + j + 1) * d; - const float* k2 = K + (n_start + j + 2) * d; - const float* k3 = K + (n_start + j + 3) * d; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); - - int k = 0; - for (; k + 15 < d; k += 16) - { - __m512 qv = _mm512_loadu_ps(q + k); - acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); - } - if (k < d) - { - __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); - __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); - acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); - } - - s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; - s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; - } - - for (; j < block_n; j++) - { - __m512 acc = _mm512_setzero_ps(); - int k = 0; - for (; k + 15 < d; k += 16) - acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), _mm512_loadu_ps(K + (n_start + j) * d + k), acc); - if (k < d) - { - __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); - acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), _mm512_maskz_loadu_ps(mask_d, K + (n_start + j) * d + k), acc); - } - s[j] = _mm512_comp_reduce_add_ps(acc) * scale; - } -} - -static inline void decode_pv_gemv_avx512(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) -{ - for (int j = 0; j < block_n; j++) - { - __m512 pvec = _mm512_set1_ps(s[j]); - int k = 0; - for (; k + 15 < out_d; k += 16) - { 
- __m512 oval = _mm512_loadu_ps(out + k); - __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); - _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vval, oval)); - } - if (k < out_d) - { - __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); - __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); - __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); - _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec, vval, oval)); + acc += pptr[mi][j] * V[j * D + dd]; + op[mi][dd] = acc; + } } } -} - -static void sdpa_decode_avx512(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) -{ - const int BLOCK_N = 128; - __attribute__((aligned(64))) float s[BLOCK_N]; - - vec_zero_avx512(out, out_d); - - float m = -FLT_MAX; - float l = 0.f; - for (int n_start = 0; n_start < n; n_start += BLOCK_N) + for (; i < m; i++) { - int block_n = std::min(BLOCK_N, n - n_start); - - decode_qk_dot_avx512(s, q, K, n_start, block_n, d, scale); + float* optr = O + i * D; + const float* pptr = P + i * n; - if (mask) + int dd = 0; + for (; dd + 127 < D; dd += 128) { - int j = 0; - for (; j + 15 < block_n; j += 16) - _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); - if (j < block_n) + __m512 acc0 = _mm512_loadu_ps(optr + dd + 0 * 16); + __m512 acc1 = _mm512_loadu_ps(optr + dd + 1 * 16); + __m512 acc2 = _mm512_loadu_ps(optr + dd + 2 * 16); + __m512 acc3 = _mm512_loadu_ps(optr + dd + 3 * 16); + __m512 acc4 = _mm512_loadu_ps(optr + dd + 4 * 16); + __m512 acc5 = _mm512_loadu_ps(optr + dd + 5 * 16); + __m512 acc6 = _mm512_loadu_ps(optr + dd + 6 * 16); + __m512 acc7 = _mm512_loadu_ps(optr + dd + 7 * 16); + + for (int j = 0; j < n; j++) { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - _mm512_mask_storeu_ps(s + j, mask_n, - _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); + __m512 
pvec = _mm512_set1_ps(pptr[j]); + acc0 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 0 * 16), acc0); + acc1 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 1 * 16), acc1); + acc2 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 2 * 16), acc2); + acc3 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 3 * 16), acc3); + acc4 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 4 * 16), acc4); + acc5 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 5 * 16), acc5); + acc6 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 6 * 16), acc6); + acc7 = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * D + dd + 7 * 16), acc7); } - } - __m512 vmax = _mm512_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + _mm512_storeu_ps(optr + dd + 0 * 16, acc0); + _mm512_storeu_ps(optr + dd + 1 * 16, acc1); + _mm512_storeu_ps(optr + dd + 2 * 16, acc2); + _mm512_storeu_ps(optr + dd + 3 * 16, acc3); + _mm512_storeu_ps(optr + dd + 4 * 16, acc4); + _mm512_storeu_ps(optr + dd + 5 * 16, acc5); + _mm512_storeu_ps(optr + dd + 6 * 16, acc6); + _mm512_storeu_ps(optr + dd + 7 * 16, acc7); } - float tile_m = _mm512_comp_reduce_max_ps(vmax); - float new_m = std::max(m, tile_m); - if (m != new_m) + for (; dd + 15 < D; dd += 16) { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_avx512(out, scale_factor, out_d); + __m512 acc = _mm512_loadu_ps(optr + dd); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * D + dd), acc); + _mm512_storeu_ps(optr + dd, acc); } - __m512 vm_new = _mm512_set1_ps(new_m); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s 
+ j), vm_new)); - _mm512_storeu_ps(s + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) + for (; dd < D; dd++) { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); - _mm512_mask_storeu_ps(s + j, mask_n, pvec); - vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * D + dd]; + optr[dd] = acc; } - l += _mm512_comp_reduce_add_ps(vsum); - - decode_pv_gemv_avx512(out, s, V, n_start, block_n, out_d); - - m = new_m; } - - float inv_l = 1.f / l; - vec_scale_avx512(out, inv_l, out_d); } + #endif // __AVX512F__ #if __AVX__ @@ -2406,185 +2300,17 @@ void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, for (; j + 4 <= n; j += 4) { const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - __m256 acc2 = _mm256_setzero_ps(); - __m256 acc3 = _mm256_setzero_ps(); - - for (int k = 0; k < 1024; k += 8) - { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); - acc2 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k2 + k), acc2); - acc3 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm256_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm256_reduce_add_ps(acc3) * scale; - } - - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = 
_mm256_setzero_ps(); - - for (int k = 0; k < 1024; k += 8) - { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); - } - - S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * 1024; - const float* kptr = K + j * 1024; - __m256 vacc = _mm256_setzero_ps(); - for (int k = 0; k < 1024; k += 8) - vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; - } - } -} - -template<> -void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 2 <= m; i += 2) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - const float* k2 = K + (j + 2) * 2048; - const float* k3 = K + (j + 3) * 2048; - - __m256 acc[2][4]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - acc[mi][2] = _mm256_setzero_ps(); - acc[mi][3] = _mm256_setzero_ps(); - } - - for (int k = 0; k < 2048; k += 8) - { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); - __m256 kv2 = _mm256_loadu_ps(k2 + k); - __m256 kv3 = _mm256_loadu_ps(k3 + k); - - for (int mi = 0; mi < 2; mi++) - { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 2048 + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); - } - } - - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * 
scale; - S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; - } - } - - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - - __m256 acc[2][2]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - } - - for (int k = 0; k < 2048; k += 8) - { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); - - for (int mi = 0; mi < 2; mi++) - { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 2048 + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - const float* kptr = K + j * 2048; - - __m256 acc[2]; - for (int mi = 0; mi < 2; mi++) - acc[mi] = _mm256_setzero_ps(); - - for (int k = 0; k < 2048; k += 8) - { - __m256 kvec = _mm256_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) - { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 2048 + k); - acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); - } - } - - for (int mi = 0; mi < 2; mi++) - S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 2048; - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - const float* k2 = K + (j + 2) * 2048; - const float* k3 = K + (j + 3) * 2048; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; __m256 acc0 = _mm256_setzero_ps(); __m256 acc1 = _mm256_setzero_ps(); __m256 acc2 = _mm256_setzero_ps(); __m256 acc3 = 
_mm256_setzero_ps(); - for (int k = 0; k < 2048; k += 8) + for (int k = 0; k < 1024; k += 8) { __m256 qvec = _mm256_loadu_ps(qptr + k); acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); @@ -2601,14 +2327,14 @@ void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* qptr = Q + i * 2048; - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; __m256 acc0 = _mm256_setzero_ps(); __m256 acc1 = _mm256_setzero_ps(); - for (int k = 0; k < 2048; k += 8) + for (int k = 0; k < 1024; k += 8) { __m256 qvec = _mm256_loadu_ps(qptr + k); acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); @@ -2621,18 +2347,18 @@ void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* qptr = Q + i * 2048; - const float* kptr = K + j * 2048; + const float* qptr = Q + i * 1024; + const float* kptr = K + j * 1024; __m256 vacc = _mm256_setzero_ps(); - for (int k = 0; k < 2048; k += 8) + for (int k = 0; k < 1024; k += 8) vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; } } } -template<> -void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, +template +static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; @@ -2641,10 +2367,10 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; - const float* k2 = K + (j + 2) * 4096; - const float* k3 = K + (j + 3) * 4096; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K 
+ (j + 3) * D; __m256 acc[2][4]; for (int mi = 0; mi < 2; mi++) @@ -2655,7 +2381,7 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, acc[mi][3] = _mm256_setzero_ps(); } - for (int k = 0; k < 4096; k += 8) + for (int k = 0; k < D; k += 8) { __m256 kv0 = _mm256_loadu_ps(k0 + k); __m256 kv1 = _mm256_loadu_ps(k1 + k); @@ -2664,7 +2390,7 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, for (int mi = 0; mi < 2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 4096 + k); + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); @@ -2683,8 +2409,8 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; __m256 acc[2][2]; for (int mi = 0; mi < 2; mi++) @@ -2693,14 +2419,14 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, acc[mi][1] = _mm256_setzero_ps(); } - for (int k = 0; k < 4096; k += 8) + for (int k = 0; k < D; k += 8) { __m256 kv0 = _mm256_loadu_ps(k0 + k); __m256 kv1 = _mm256_loadu_ps(k1 + k); for (int mi = 0; mi < 2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 4096 + k); + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); } @@ -2715,18 +2441,18 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 4096; + const float* kptr = K + j * D; __m256 acc[2]; for (int mi = 0; mi < 2; mi++) acc[mi] = _mm256_setzero_ps(); - for (int k = 0; k < 4096; k += 8) + for (int k = 0; k < D; k += 8) { __m256 kvec = 
_mm256_loadu_ps(kptr + k); for (int mi = 0; mi < 2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 4096 + k); + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); } } @@ -2741,18 +2467,18 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* qptr = Q + i * 4096; - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; - const float* k2 = K + (j + 2) * 4096; - const float* k3 = K + (j + 3) * 4096; + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; __m256 acc0 = _mm256_setzero_ps(); __m256 acc1 = _mm256_setzero_ps(); __m256 acc2 = _mm256_setzero_ps(); __m256 acc3 = _mm256_setzero_ps(); - for (int k = 0; k < 4096; k += 8) + for (int k = 0; k < D; k += 8) { __m256 qvec = _mm256_loadu_ps(qptr + k); acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); @@ -2769,14 +2495,14 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* qptr = Q + i * 4096; - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; __m256 acc0 = _mm256_setzero_ps(); __m256 acc1 = _mm256_setzero_ps(); - for (int k = 0; k < 4096; k += 8) + for (int k = 0; k < D; k += 8) { __m256 qvec = _mm256_loadu_ps(qptr + k); acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); @@ -2789,10 +2515,10 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* qptr = Q + i * 4096; - const float* kptr = K + j * 4096; + const float* qptr = Q + i * D; + const float* kptr = K + j * D; __m256 vacc = _mm256_setzero_ps(); - for (int k = 0; k < 4096; k += 8) + 
for (int k = 0; k < D; k += 8) vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; } @@ -2800,6 +2526,22 @@ void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, } +template<> +void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<2048>(S, Q, K, m, n, scale); +} + + +template<> +void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<4096>(S, Q, K, m, n, scale); +} + + template static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) { @@ -2900,267 +2642,74 @@ static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, for (int j = 0; j < n; j++) acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), _mm256_loadu_ps(V + j * d + dd), acc); _mm256_storeu_ps(optr + dd, acc); - } - - for (; dd < d; dd++) - { - float acc = optr[dd]; - for (int j = 0; j < n; j++) - acc += pptr[j] * V[j * d + dd]; - optr[dd] = acc; - } - } -} - - -template -static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) -{ - const int VEC_PER_D = D / 8; - int i = 0; - for (; i + M_BLOCK <= m; i += M_BLOCK) - { - float* op[M_BLOCK]; - const float* pptr[M_BLOCK]; - for (int mi = 0; mi < M_BLOCK; mi++) - { - op[mi] = O + (i + mi) * D; - pptr[mi] = P + (i + mi) * n; - } - - for (int mi = 0; mi < M_BLOCK; mi++) - { - __m256 acc[VEC_PER_D]; - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); - - for (int j = 0; j < n; j++) - { - __m256 pvec = _mm256_set1_ps(pptr[mi][j]); - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); - } - - for (int vi = 0; vi < VEC_PER_D; vi++) - _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); - } - } - - for (; i 
< m; i++) - { - float* optr = O + i * D; - const float* pptr = P + i * n; - - __m256 acc[VEC_PER_D]; - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_loadu_ps(optr + vi * 8); - - for (int j = 0; j < n; j++) - { - __m256 pvec = _mm256_set1_ps(pptr[j]); - for (int vi = 0; vi < VEC_PER_D; vi++) - acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); - } - - for (int vi = 0; vi < VEC_PER_D; vi++) - _mm256_storeu_ps(optr + vi * 8, acc[vi]); - } -} - - -static inline void softmax_tile_avx(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) -{ - for (int i = 0; i < m; i++) - { - const float* sptr = S + i * n; - float* pptr = P + i * n; - - __m256 vmax = _mm256_set1_ps(m_vec[i]); - int j = 0; - for (; j + 7 < n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(sptr + j)); - float m_new = _mm256_reduce_max_ps(vmax); - for (; j < n; j++) - m_new = std::max(m_new, sptr[j]); - - float scale_factor = expf(m_vec[i] - m_new); - scale_out[i] = scale_factor; - l_vec[i] *= scale_factor; - - __m256 vm_new = _mm256_set1_ps(m_new); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < n; j += 8) - { - __m256 svec = _mm256_loadu_ps(sptr + j); - __m256 evec = exp256_ps(_mm256_sub_ps(svec, vm_new)); - _mm256_storeu_ps(pptr + j, evec); - vsum = _mm256_add_ps(vsum, evec); - } - float l_add = _mm256_reduce_add_ps(vsum); - for (; j < n; j++) - { - pptr[j] = expf(sptr[j] - m_new); - l_add += pptr[j]; - } - l_vec[i] += l_add; - m_vec[i] = m_new; - } -} - -static inline void vec_scale_avx(float* x, float s, int n) -{ - __m256 vscale = _mm256_set1_ps(s); - int i = 0; - for (; i + 7 < n; i += 8) - _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale)); - for (; i < n; i++) - x[i] *= s; -} - -static inline void vec_zero_avx(float* x, int n) -{ - __m256 zero = _mm256_setzero_ps(); - int i = 0; - for (; i + 7 < n; i += 8) - _mm256_storeu_ps(x + i, zero); - for (; i < n; i++) - x[i] = 0.f; -} - 
-static inline void decode_qk_dot_avx(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) -{ - int j = 0; - for (; j + 1 < block_n; j += 2) - { - const float* k0 = K + (n_start + j + 0) * d; - const float* k1 = K + (n_start + j + 1) * d; - - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - - int k = 0; - for (; k + 7 < d; k += 8) - { - __m256 qv = _mm256_loadu_ps(q + k); - acc0 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qv, _mm256_loadu_ps(k1 + k), acc1); - } - - float sum0 = _mm256_reduce_add_ps(acc0); - float sum1 = _mm256_reduce_add_ps(acc1); - - for (; k < d; k++) - { - sum0 += q[k] * k0[k]; - sum1 += q[k] * k1[k]; - } - - s[j + 0] = sum0 * scale; - s[j + 1] = sum1 * scale; - } - - for (; j < block_n; j++) - { - const float* kptr = K + (n_start + j) * d; - __m256 acc = _mm256_setzero_ps(); - int k = 0; - for (; k + 7 < d; k += 8) - acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), _mm256_loadu_ps(kptr + k), acc); - float sum = _mm256_reduce_add_ps(acc); - for (; k < d; k++) - sum += q[k] * kptr[k]; - s[j] = sum * scale; - } -} - -static inline void decode_pv_gemv_avx(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) -{ - for (int j = 0; j < block_n; j++) - { - __m256 pvec = _mm256_set1_ps(s[j]); - int k = 0; - for (; k + 7 < out_d; k += 8) - { - __m256 oval = _mm256_loadu_ps(out + k); - __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); - _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec, vval, oval)); - } - for (; k < out_d; k++) - out[k] += s[j] * V[(n_start + j) * out_d + k]; - } -} - -static void sdpa_decode_avx(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) -{ - const int BLOCK_N = 128; - __attribute__((aligned(32))) float s[BLOCK_N]; - - vec_zero_avx(out, out_d); - - float m = -FLT_MAX; - float l = 0.f; - - for (int n_start = 0; 
n_start < n; n_start += BLOCK_N) - { - int block_n = std::min(BLOCK_N, n - n_start); - - decode_qk_dot_avx(s, q, K, n_start, block_n, d, scale); + } - if (mask) + for (; dd < d; dd++) { - int j = 0; - for (; j + 7 < block_n; j += 8) - _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; + float acc = optr[dd]; + for (int j = 0; j < n; j++) + acc += pptr[j] * V[j * d + dd]; + optr[dd] = acc; } + } +} - __m256 vmax = _mm256_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); - float tile_m = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); - float new_m = std::max(m, tile_m); - if (m != new_m) +template +static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) +{ + const int VEC_PER_D = D / 8; + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op[M_BLOCK]; + const float* pptr[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_avx(out, scale_factor, out_d); + op[mi] = O + (i + mi) * D; + pptr[mi] = P + (i + mi) * n; } - __m256 vm_new = _mm256_set1_ps(new_m); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) + for (int mi = 0; mi < M_BLOCK; mi++) { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); - _mm256_storeu_ps(s + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(op[mi] + vi * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); + } + + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(op[mi] + vi * 8, acc[vi]); } - float l_add = 
_mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) + } + + for (; i < m; i++) + { + float* optr = O + i * D; + const float* pptr = P + i * n; + + __m256 acc[VEC_PER_D]; + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_loadu_ps(optr + vi * 8); + + for (int j = 0; j < n; j++) { - s[j] = expf(s[j] - new_m); - l_add += s[j]; + __m256 pvec = _mm256_set1_ps(pptr[j]); + for (int vi = 0; vi < VEC_PER_D; vi++) + acc[vi] = _mm256_comp_fmadd_ps(pvec, _mm256_loadu_ps(V + j * D + vi * 8), acc[vi]); } - l += l_add; - decode_pv_gemv_avx(out, s, V, n_start, block_n, out_d); - - m = new_m; + for (int vi = 0; vi < VEC_PER_D; vi++) + _mm256_storeu_ps(optr + vi * 8, acc[vi]); } - - float inv_l = 1.f / l; - vec_scale_avx(out, inv_l, out_d); } + #endif // __AVX__ #if __SSE2__ @@ -3371,197 +2920,29 @@ void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, acc[mi] = _mm_setzero_ps(); for (int k = 0; k < 1024; k += 4) - { - __m128 kvec = _mm_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) - { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); - } - } - - for (int mi = 0; mi < 4; mi++) - S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i + 2 <= m; i += 2) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m128 acc[2][4]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm_setzero_ps(); - acc[mi][1] = _mm_setzero_ps(); - acc[mi][2] = _mm_setzero_ps(); - acc[mi][3] = _mm_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 4) - { - __m128 kv0 = _mm_loadu_ps(k0 + k); - __m128 kv1 = _mm_loadu_ps(k1 + k); - __m128 kv2 = _mm_loadu_ps(k2 + k); - __m128 kv3 = _mm_loadu_ps(k3 + k); - - for (int mi = 0; mi < 2; mi++) - { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = 
_mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); - } - } - - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; - } - } - - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m128 acc[2][2]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm_setzero_ps(); - acc[mi][1] = _mm_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 4) - { - __m128 kv0 = _mm_loadu_ps(k0 + k); - __m128 kv1 = _mm_loadu_ps(k1 + k); - - for (int mi = 0; mi < 2; mi++) - { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - const float* kptr = K + j * 1024; - - __m128 acc[2]; - for (int mi = 0; mi < 2; mi++) - acc[mi] = _mm_setzero_ps(); - - for (int k = 0; k < 1024; k += 4) - { - __m128 kvec = _mm_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) - { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); - } - } - - for (int mi = 0; mi < 2; mi++) - S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 
1024; - const float* k3 = K + (j + 3) * 1024; - - __m128 acc0 = _mm_setzero_ps(); - __m128 acc1 = _mm_setzero_ps(); - __m128 acc2 = _mm_setzero_ps(); - __m128 acc3 = _mm_setzero_ps(); - - for (int k = 0; k < 1024; k += 4) - { - __m128 qvec = _mm_loadu_ps(qptr + k); - acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); - acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); - acc2 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k2 + k), acc2); - acc3 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm_reduce_add_ps(acc3) * scale; - } - - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m128 acc0 = _mm_setzero_ps(); - __m128 acc1 = _mm_setzero_ps(); - - for (int k = 0; k < 1024; k += 4) - { - __m128 qvec = _mm_loadu_ps(qptr + k); - acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); - acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); - } - - S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * 1024; - const float* kptr = K + j * 1024; - __m128 vacc = _mm_setzero_ps(); - for (int k = 0; k < 1024; k += 4) - vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; + { + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < 4; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; } } -} -template<> -void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, - int m, int n, float 
scale) -{ - int i = 0; for (; i + 2 <= m; i += 2) { int j = 0; for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - const float* k2 = K + (j + 2) * 2048; - const float* k3 = K + (j + 3) * 2048; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; __m128 acc[2][4]; for (int mi = 0; mi < 2; mi++) @@ -3572,7 +2953,7 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, acc[mi][3] = _mm_setzero_ps(); } - for (int k = 0; k < 2048; k += 4) + for (int k = 0; k < 1024; k += 4) { __m128 kv0 = _mm_loadu_ps(k0 + k); __m128 kv1 = _mm_loadu_ps(k1 + k); @@ -3581,7 +2962,7 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 2048 + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); @@ -3600,8 +2981,8 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; __m128 acc[2][2]; for (int mi = 0; mi < 2; mi++) @@ -3610,14 +2991,14 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, acc[mi][1] = _mm_setzero_ps(); } - for (int k = 0; k < 2048; k += 4) + for (int k = 0; k < 1024; k += 4) { __m128 kv0 = _mm_loadu_ps(k0 + k); __m128 kv1 = _mm_loadu_ps(k1 + k); for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 2048 + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm_comp_fmadd_ps(qvec, 
kv1, acc[mi][1]); } @@ -3632,18 +3013,18 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 2048; + const float* kptr = K + j * 1024; __m128 acc[2]; for (int mi = 0; mi < 2; mi++) acc[mi] = _mm_setzero_ps(); - for (int k = 0; k < 2048; k += 4) + for (int k = 0; k < 1024; k += 4) { __m128 kvec = _mm_loadu_ps(kptr + k); for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 2048 + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); } } @@ -3658,18 +3039,18 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* qptr = Q + i * 2048; - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; - const float* k2 = K + (j + 2) * 2048; - const float* k3 = K + (j + 3) * 2048; + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; + const float* k2 = K + (j + 2) * 1024; + const float* k3 = K + (j + 3) * 1024; __m128 acc0 = _mm_setzero_ps(); __m128 acc1 = _mm_setzero_ps(); __m128 acc2 = _mm_setzero_ps(); __m128 acc3 = _mm_setzero_ps(); - for (int k = 0; k < 2048; k += 4) + for (int k = 0; k < 1024; k += 4) { __m128 qvec = _mm_loadu_ps(qptr + k); acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); @@ -3686,14 +3067,14 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* qptr = Q + i * 2048; - const float* k0 = K + (j + 0) * 2048; - const float* k1 = K + (j + 1) * 2048; + const float* qptr = Q + i * 1024; + const float* k0 = K + (j + 0) * 1024; + const float* k1 = K + (j + 1) * 1024; __m128 acc0 = _mm_setzero_ps(); __m128 acc1 = _mm_setzero_ps(); - for (int k = 0; k < 2048; k += 4) + for (int k = 0; k < 1024; k += 4) { __m128 qvec = _mm_loadu_ps(qptr + k); acc0 = 
_mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); @@ -3706,18 +3087,18 @@ void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* qptr = Q + i * 2048; - const float* kptr = K + j * 2048; + const float* qptr = Q + i * 1024; + const float* kptr = K + j * 1024; __m128 vacc = _mm_setzero_ps(); - for (int k = 0; k < 2048; k += 4) + for (int k = 0; k < 1024; k += 4) vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } } } -template<> -void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, +template +static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; @@ -3726,10 +3107,10 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; - const float* k2 = K + (j + 2) * 4096; - const float* k3 = K + (j + 3) * 4096; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; __m128 acc[2][4]; for (int mi = 0; mi < 2; mi++) @@ -3740,7 +3121,7 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, acc[mi][3] = _mm_setzero_ps(); } - for (int k = 0; k < 4096; k += 4) + for (int k = 0; k < D; k += 4) { __m128 kv0 = _mm_loadu_ps(k0 + k); __m128 kv1 = _mm_loadu_ps(k1 + k); @@ -3749,7 +3130,7 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 4096 + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); @@ -3768,8 +3149,8 @@ void 
qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; __m128 acc[2][2]; for (int mi = 0; mi < 2; mi++) @@ -3778,14 +3159,14 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, acc[mi][1] = _mm_setzero_ps(); } - for (int k = 0; k < 4096; k += 4) + for (int k = 0; k < D; k += 4) { __m128 kv0 = _mm_loadu_ps(k0 + k); __m128 kv1 = _mm_loadu_ps(k1 + k); for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 4096 + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); } @@ -3800,18 +3181,18 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* kptr = K + j * 4096; + const float* kptr = K + j * D; __m128 acc[2]; for (int mi = 0; mi < 2; mi++) acc[mi] = _mm_setzero_ps(); - for (int k = 0; k < 4096; k += 4) + for (int k = 0; k < D; k += 4) { __m128 kvec = _mm_loadu_ps(kptr + k); for (int mi = 0; mi < 2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 4096 + k); + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); } } @@ -3826,18 +3207,18 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, int j = 0; for (; j + 4 <= n; j += 4) { - const float* qptr = Q + i * 4096; - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; - const float* k2 = K + (j + 2) * 4096; - const float* k3 = K + (j + 3) * 4096; + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; __m128 acc0 = _mm_setzero_ps(); __m128 acc1 = _mm_setzero_ps(); __m128 acc2 = 
_mm_setzero_ps(); __m128 acc3 = _mm_setzero_ps(); - for (int k = 0; k < 4096; k += 4) + for (int k = 0; k < D; k += 4) { __m128 qvec = _mm_loadu_ps(qptr + k); acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); @@ -3854,14 +3235,14 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, for (; j + 2 <= n; j += 2) { - const float* qptr = Q + i * 4096; - const float* k0 = K + (j + 0) * 4096; - const float* k1 = K + (j + 1) * 4096; + const float* qptr = Q + i * D; + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; __m128 acc0 = _mm_setzero_ps(); __m128 acc1 = _mm_setzero_ps(); - for (int k = 0; k < 4096; k += 4) + for (int k = 0; k < D; k += 4) { __m128 qvec = _mm_loadu_ps(qptr + k); acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); @@ -3874,10 +3255,10 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, for (; j < n; j++) { - const float* qptr = Q + i * 4096; - const float* kptr = K + j * 4096; + const float* qptr = Q + i * D; + const float* kptr = K + j * D; __m128 vacc = _mm_setzero_ps(); - for (int k = 0; k < 4096; k += 4) + for (int k = 0; k < D; k += 4) vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } @@ -3885,6 +3266,22 @@ void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, } +template<> +void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<2048>(S, Q, K, m, n, scale); +} + + +template<> +void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<4096>(S, Q, K, m, n, scale); +} + + template static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) { @@ -4053,168 +3450,6 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, } 
-static inline void softmax_tile_sse2(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) -{ - for (int i = 0; i < m; i++) - { - const float* sptr = S + i * n; - float* pptr = P + i * n; - - __m128 vmax = _mm_set1_ps(m_vec[i]); - int j = 0; - for (; j + 3 < n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(sptr + j)); - float m_new = _mm_reduce_max_ps(vmax); - for (; j < n; j++) - m_new = std::max(m_new, sptr[j]); - - float scale_factor = expf(m_vec[i] - m_new); - scale_out[i] = scale_factor; - l_vec[i] *= scale_factor; - - __m128 vm_new = _mm_set1_ps(m_new); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < n; j += 4) - { - __m128 svec = _mm_loadu_ps(sptr + j); - __m128 evec = exp_ps(_mm_sub_ps(svec, vm_new)); - _mm_storeu_ps(pptr + j, evec); - vsum = _mm_add_ps(vsum, evec); - } - float l_add = _mm_reduce_add_ps(vsum); - for (; j < n; j++) - { - pptr[j] = expf(sptr[j] - m_new); - l_add += pptr[j]; - } - l_vec[i] += l_add; - m_vec[i] = m_new; - } -} - -static inline void vec_scale_sse2(float* x, float s, int n) -{ - __m128 vscale = _mm_set1_ps(s); - int i = 0; - for (; i + 3 < n; i += 4) - _mm_storeu_ps(x + i, _mm_mul_ps(_mm_loadu_ps(x + i), vscale)); - for (; i < n; i++) - x[i] *= s; -} - -static inline void vec_zero_sse2(float* x, int n) -{ - __m128 zero = _mm_setzero_ps(); - int i = 0; - for (; i + 3 < n; i += 4) - _mm_storeu_ps(x + i, zero); - for (; i < n; i++) - x[i] = 0.f; -} - -static inline void decode_qk_dot_sse2(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) -{ - for (int j = 0; j < block_n; j++) - { - __m128 acc = _mm_setzero_ps(); - int k = 0; - for (; k + 3 < d; k += 4) - acc = _mm_comp_fmadd_ps(_mm_loadu_ps(q + k), _mm_loadu_ps(K + (n_start + j) * d + k), acc); - float sum = _mm_reduce_add_ps(acc); - for (; k < d; k++) - sum += q[k] * K[(n_start + j) * d + k]; - s[j] = sum * scale; - } -} - -static inline void decode_pv_gemv_sse2(float* out, const float* s, const 
float* V, int n_start, int block_n, int out_d) -{ - for (int j = 0; j < block_n; j++) - { - __m128 pvec = _mm_set1_ps(s[j]); - int k = 0; - for (; k + 3 < out_d; k += 4) - { - __m128 oval = _mm_loadu_ps(out + k); - __m128 vval = _mm_loadu_ps(V + (n_start + j) * out_d + k); - _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec, vval, oval)); - } - for (; k < out_d; k++) - out[k] += s[j] * V[(n_start + j) * out_d + k]; - } -} - -static void sdpa_decode_sse2(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) -{ - const int BLOCK_N = 128; - __attribute__((aligned(16))) float s[BLOCK_N]; - - vec_zero_sse2(out, out_d); - - float m = -FLT_MAX; - float l = 0.f; - - for (int n_start = 0; n_start < n; n_start += BLOCK_N) - { - int block_n = std::min(BLOCK_N, n - n_start); - - decode_qk_dot_sse2(s, q, K, n_start, block_n, d, scale); - - if (mask) - { - int j = 0; - for (; j + 3 < block_n; j += 4) - _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; - } - - __m128 vmax = _mm_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); - float tile_m = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); - - float new_m = std::max(m, tile_m); - if (m != new_m) - { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale_sse2(out, scale_factor, out_d); - } - - __m128 vm_new = _mm_set1_ps(new_m); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); - _mm_storeu_ps(s + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - float l_add = _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; - - decode_pv_gemv_sse2(out, s, V, n_start, block_n, out_d); - - m = new_m; - } - - 
float inv_l = 1.f / l; - vec_scale_sse2(out, inv_l, out_d); -} - #endif // __SSE2__ @@ -4484,57 +3719,25 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, static inline void vec_scale_dispatch(float* x, float s, int n) { -#if __AVX512F__ - vec_scale_avx512(x, s, n); -#elif __AVX__ - vec_scale_avx(x, s, n); -#elif __SSE2__ - vec_scale_sse2(x, s, n); -#else - vec_scale_scalar(x, s, n); -#endif + vec_scale(x, s, n); } static inline void vec_zero_dispatch(float* x, int n) { -#if __AVX512F__ - vec_zero_avx512(x, n); -#elif __AVX__ - vec_zero_avx(x, n); -#elif __SSE2__ - vec_zero_sse2(x, n); -#else - vec_zero_scalar(x, n); -#endif + vec_zero(x, n); } static inline void softmax_tile_dispatch(float* P, const float* S, float* m_vec, float* l_vec, float* scale_out, int m, int n) { -#if __AVX512F__ - softmax_tile_avx512(P, S, m_vec, l_vec, scale_out, m, n); -#elif __AVX__ - softmax_tile_avx(P, S, m_vec, l_vec, scale_out, m, n); -#elif __SSE2__ - softmax_tile_sse2(P, S, m_vec, l_vec, scale_out, m, n); -#else - softmax_tile_scalar(P, S, m_vec, l_vec, scale_out, m, n); -#endif + softmax_tile(P, S, m_vec, l_vec, scale_out, m, n); } static inline void sdpa_decode_dispatch(float* out, const float* q, const float* K, const float* V, const float* mask, int n, int d, int out_d, float scale) { -#if __AVX512F__ - sdpa_decode_avx512(out, q, K, V, mask, n, d, out_d, scale); -#elif __AVX__ - sdpa_decode_avx(out, q, K, V, mask, n, d, out_d, scale); -#elif __SSE2__ - sdpa_decode_sse2(out, q, K, V, mask, n, d, out_d, scale); -#else - sdpa_decode_scalar(out, q, K, V, mask, n, d, out_d, scale); -#endif + sdpa_decode(out, q, K, V, mask, n, d, out_d, scale); } // Timing instrumentation removed From 66c0b68ac98d19424076e570df9a70de8e5d04a2 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Tue, 28 Apr 2026 10:21:00 +0800 Subject: [PATCH 17/53] slim sdpa --- src/layer/x86/sdpa_x86.cpp | 690 ++++++++++--------------------------- 1 file changed, 190 
insertions(+), 500 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 593d8193431d..662854701892 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -2089,280 +2089,115 @@ static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float } } -template<> -void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, +template +static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; - for (; i + 4 <= m; i += 4) + if (M_BLOCK2 > 0) { - int j = 0; - for (; j + 4 <= n; j += 4) + for (; i + M_BLOCK2 <= m; i += M_BLOCK2) { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m256 acc[4][4]; - for (int mi = 0; mi < 4; mi++) - { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - acc[mi][2] = _mm256_setzero_ps(); - acc[mi][3] = _mm256_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 8) + int j = 0; + for (; j + 4 <= n; j += 4) { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); - __m256 kv2 = _mm256_loadu_ps(k2 + k); - __m256 kv3 = _mm256_loadu_ps(k3 + k); + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; - for (int mi = 0; mi < 4; mi++) + __m256 acc[M_BLOCK2][4]; + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); + acc[mi][2] = _mm256_setzero_ps(); + acc[mi][3] = _mm256_setzero_ps(); } - } - - 
for (int mi = 0; mi < 4; mi++) - { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; - } - } - - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - __m256 acc[4][2]; - for (int mi = 0; mi < 4; mi++) - { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 8) - { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); - - for (int mi = 0; mi < 4; mi++) + for (int k = 0; k < D; k += 8) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 4; mi++) - { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; - } - } + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); + __m256 kv2 = _mm256_loadu_ps(k2 + k); + __m256 kv3 = _mm256_loadu_ps(k3 + k); - for (; j < n; j++) - { - const float* kptr = K + j * 1024; - - __m256 acc[4]; - for (int mi = 0; mi < 4; mi++) - acc[mi] = _mm256_setzero_ps(); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } - for (int k = 0; k < 1024; k += 8) - { - __m256 kvec = _mm256_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 
+ k); - acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; } } - for (int mi = 0; mi < 4; mi++) - S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i + 2 <= m; i += 2) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m256 acc[2][4]; - for (int mi = 0; mi < 2; mi++) + for (; j + 2 <= n; j += 2) { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - acc[mi][2] = _mm256_setzero_ps(); - acc[mi][3] = _mm256_setzero_ps(); - } + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; - for (int k = 0; k < 1024; k += 8) - { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); - __m256 kv2 = _mm256_loadu_ps(k2 + k); - __m256 kv3 = _mm256_loadu_ps(k3 + k); - - for (int mi = 0; mi < 2; mi++) + __m256 acc[M_BLOCK2][2]; + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm256_comp_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm256_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + acc[mi][0] = _mm256_setzero_ps(); + acc[mi][1] = _mm256_setzero_ps(); } - } - - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm256_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm256_reduce_add_ps(acc[mi][3]) * scale; - 
} - } - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m256 acc[2][2]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm256_setzero_ps(); - acc[mi][1] = _mm256_setzero_ps(); - } + for (int k = 0; k < D; k += 8) + { + __m256 kv0 = _mm256_loadu_ps(k0 + k); + __m256 kv1 = _mm256_loadu_ps(k1 + k); - for (int k = 0; k < 1024; k += 8) - { - __m256 kv0 = _mm256_loadu_ps(k0 + k); - __m256 kv1 = _mm256_loadu_ps(k1 + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; } } - for (int mi = 0; mi < 2; mi++) + for (; j < n; j++) { - S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; - } - } + const float* kptr = K + j * D; - for (; j < n; j++) - { - const float* kptr = K + j * 1024; - - __m256 acc[2]; - for (int mi = 0; mi < 2; mi++) - acc[mi] = _mm256_setzero_ps(); + __m256 acc[M_BLOCK2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + acc[mi] = _mm256_setzero_ps(); - for (int k = 0; k < 1024; k += 8) - { - __m256 kvec = _mm256_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int k = 0; k < D; k += 8) { - __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + __m256 kvec = _mm256_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + 
k); + acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); + } } - } - - for (int mi = 0; mi < 2; mi++) - S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - __m256 acc2 = _mm256_setzero_ps(); - __m256 acc3 = _mm256_setzero_ps(); - - for (int k = 0; k < 1024; k += 8) - { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); - acc2 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k2 + k), acc2); - acc3 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm256_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm256_reduce_add_ps(acc3) * scale; - } - - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m256 acc0 = _mm256_setzero_ps(); - __m256 acc1 = _mm256_setzero_ps(); - for (int k = 0; k < 1024; k += 8) - { - __m256 qvec = _mm256_loadu_ps(qptr + k); - acc0 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k0 + k), acc0); - acc1 = _mm256_comp_fmadd_ps(qvec, _mm256_loadu_ps(k1 + k), acc1); + for (int mi = 0; mi < M_BLOCK2; mi++) + S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; } - - S[i * n + j + 0] = _mm256_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm256_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * 1024; - const float* kptr = K + j * 1024; - __m256 vacc = _mm256_setzero_ps(); - 
for (int k = 0; k < 1024; k += 8) - vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), _mm256_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm256_reduce_add_ps(vacc) * scale; } } -} -template -static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 2 <= m; i += 2) + for (; i + M_BLOCK1 <= m; i += M_BLOCK1) { int j = 0; for (; j + 4 <= n; j += 4) @@ -2372,8 +2207,8 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const const float* k2 = K + (j + 2) * D; const float* k3 = K + (j + 3) * D; - __m256 acc[2][4]; - for (int mi = 0; mi < 2; mi++) + __m256 acc[M_BLOCK1][4]; + for (int mi = 0; mi < M_BLOCK1; mi++) { acc[mi][0] = _mm256_setzero_ps(); acc[mi][1] = _mm256_setzero_ps(); @@ -2388,7 +2223,7 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const __m256 kv2 = _mm256_loadu_ps(k2 + k); __m256 kv3 = _mm256_loadu_ps(k3 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); @@ -2398,7 +2233,7 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; @@ -2412,8 +2247,8 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; - __m256 acc[2][2]; - for (int mi = 0; mi < 2; mi++) + __m256 acc[M_BLOCK1][2]; + for (int mi = 0; mi < M_BLOCK1; mi++) { acc[mi][0] = _mm256_setzero_ps(); acc[mi][1] = _mm256_setzero_ps(); @@ -2424,7 +2259,7 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const __m256 kv0 = _mm256_loadu_ps(k0 + k); 
__m256 kv1 = _mm256_loadu_ps(k1 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]); @@ -2432,7 +2267,7 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { S[(i + mi) * n + j + 0] = _mm256_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm256_reduce_add_ps(acc[mi][1]) * scale; @@ -2443,21 +2278,21 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const { const float* kptr = K + j * D; - __m256 acc[2]; - for (int mi = 0; mi < 2; mi++) + __m256 acc[M_BLOCK1]; + for (int mi = 0; mi < M_BLOCK1; mi++) acc[mi] = _mm256_setzero_ps(); for (int k = 0; k < D; k += 8) { __m256 kvec = _mm256_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * D + k); acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) S[(i + mi) * n + j] = _mm256_reduce_add_ps(acc[mi]) * scale; } } @@ -2526,11 +2361,19 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const } +template<> +void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<1024, 2, 4>(S, Q, K, m, n, scale); +} + + template<> void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx<2048>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx<2048, 2, 0>(S, Q, K, m, n, scale); } @@ -2538,10 +2381,12 @@ template<> void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx<4096>(S, Q, K, m, n, scale); + 
qk_gemm_specialized_tiled_avx<4096, 2, 0>(S, Q, K, m, n, scale); } + + template static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) { @@ -2829,280 +2674,115 @@ static inline void qk_gemm_specialized_sse2(float* S, const float* Q, const floa } } -template<> -void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, +template +static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, const float* K, int m, int n, float scale) { int i = 0; - for (; i + 4 <= m; i += 4) + if (M_BLOCK2 > 0) { - int j = 0; - for (; j + 4 <= n; j += 4) + for (; i + M_BLOCK2 <= m; i += M_BLOCK2) { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m128 acc[4][4]; - for (int mi = 0; mi < 4; mi++) - { - acc[mi][0] = _mm_setzero_ps(); - acc[mi][1] = _mm_setzero_ps(); - acc[mi][2] = _mm_setzero_ps(); - acc[mi][3] = _mm_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 4) + int j = 0; + for (; j + 4 <= n; j += 4) { - __m128 kv0 = _mm_loadu_ps(k0 + k); - __m128 kv1 = _mm_loadu_ps(k1 + k); - __m128 kv2 = _mm_loadu_ps(k2 + k); - __m128 kv3 = _mm_loadu_ps(k3 + k); + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; - for (int mi = 0; mi < 4; mi++) + __m128 acc[M_BLOCK2][4]; + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); + acc[mi][2] = _mm_setzero_ps(); + acc[mi][3] = _mm_setzero_ps(); } - } - - for (int mi = 0; mi < 4; mi++) - { - S[(i + mi) * n + j + 0] = 
_mm_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; - } - } - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m128 acc[4][2]; - for (int mi = 0; mi < 4; mi++) - { - acc[mi][0] = _mm_setzero_ps(); - acc[mi][1] = _mm_setzero_ps(); - } - - for (int k = 0; k < 1024; k += 4) - { - __m128 kv0 = _mm_loadu_ps(k0 + k); - __m128 kv1 = _mm_loadu_ps(k1 + k); - - for (int mi = 0; mi < 4; mi++) + for (int k = 0; k < D; k += 4) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 4; mi++) - { - S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - const float* kptr = K + j * 1024; + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); + __m128 kv2 = _mm_loadu_ps(k2 + k); + __m128 kv3 = _mm_loadu_ps(k3 + k); - __m128 acc[4]; - for (int mi = 0; mi < 4; mi++) - acc[mi] = _mm_setzero_ps(); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } - for (int k = 0; k < 1024; k += 4) - { - __m128 kvec = _mm_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) 
* n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; } } - for (int mi = 0; mi < 4; mi++) - S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i + 2 <= m; i += 2) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m128 acc[2][4]; - for (int mi = 0; mi < 2; mi++) + for (; j + 2 <= n; j += 2) { - acc[mi][0] = _mm_setzero_ps(); - acc[mi][1] = _mm_setzero_ps(); - acc[mi][2] = _mm_setzero_ps(); - acc[mi][3] = _mm_setzero_ps(); - } + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; - for (int k = 0; k < 1024; k += 4) - { - __m128 kv0 = _mm_loadu_ps(k0 + k); - __m128 kv1 = _mm_loadu_ps(k1 + k); - __m128 kv2 = _mm_loadu_ps(k2 + k); - __m128 kv3 = _mm_loadu_ps(k3 + k); - - for (int mi = 0; mi < 2; mi++) + __m128 acc[M_BLOCK2][2]; + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm_comp_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm_comp_fmadd_ps(qvec, kv3, acc[mi][3]); + acc[mi][0] = _mm_setzero_ps(); + acc[mi][1] = _mm_setzero_ps(); } - } - for (int mi = 0; mi < 2; mi++) - { - S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm_reduce_add_ps(acc[mi][3]) * scale; - } - } - - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m128 acc[2][2]; - for (int mi = 0; mi < 2; mi++) - { - acc[mi][0] = _mm_setzero_ps(); - 
acc[mi][1] = _mm_setzero_ps(); - } + for (int k = 0; k < D; k += 4) + { + __m128 kv0 = _mm_loadu_ps(k0 + k); + __m128 kv1 = _mm_loadu_ps(k1 + k); - for (int k = 0; k < 1024; k += 4) - { - __m128 kv0 = _mm_loadu_ps(k0 + k); - __m128 kv1 = _mm_loadu_ps(k1 + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK2; mi++) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm_comp_fmadd_ps(qvec, kv1, acc[mi][1]); + S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; } } - for (int mi = 0; mi < 2; mi++) + for (; j < n; j++) { - S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; - } - } + const float* kptr = K + j * D; - for (; j < n; j++) - { - const float* kptr = K + j * 1024; - - __m128 acc[2]; - for (int mi = 0; mi < 2; mi++) - acc[mi] = _mm_setzero_ps(); + __m128 acc[M_BLOCK2]; + for (int mi = 0; mi < M_BLOCK2; mi++) + acc[mi] = _mm_setzero_ps(); - for (int k = 0; k < 1024; k += 4) - { - __m128 kvec = _mm_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int k = 0; k < D; k += 4) { - __m128 qvec = _mm_loadu_ps(Q + (i + mi) * 1024 + k); - acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + __m128 kvec = _mm_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); + } } - } - - for (int mi = 0; mi < 2; mi++) - S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; - } - } - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 1024; - const 
float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - const float* k2 = K + (j + 2) * 1024; - const float* k3 = K + (j + 3) * 1024; - - __m128 acc0 = _mm_setzero_ps(); - __m128 acc1 = _mm_setzero_ps(); - __m128 acc2 = _mm_setzero_ps(); - __m128 acc3 = _mm_setzero_ps(); - - for (int k = 0; k < 1024; k += 4) - { - __m128 qvec = _mm_loadu_ps(qptr + k); - acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); - acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); - acc2 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k2 + k), acc2); - acc3 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm_reduce_add_ps(acc3) * scale; - } - - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 1024; - const float* k0 = K + (j + 0) * 1024; - const float* k1 = K + (j + 1) * 1024; - - __m128 acc0 = _mm_setzero_ps(); - __m128 acc1 = _mm_setzero_ps(); - - for (int k = 0; k < 1024; k += 4) - { - __m128 qvec = _mm_loadu_ps(qptr + k); - acc0 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k0 + k), acc0); - acc1 = _mm_comp_fmadd_ps(qvec, _mm_loadu_ps(k1 + k), acc1); + for (int mi = 0; mi < M_BLOCK2; mi++) + S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; } - - S[i * n + j + 0] = _mm_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * 1024; - const float* kptr = K + j * 1024; - __m128 vacc = _mm_setzero_ps(); - for (int k = 0; k < 1024; k += 4) - vacc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), _mm_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm_reduce_add_ps(vacc) * scale; } } -} -template -static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 2 <= m; i += 2) + for (; i + M_BLOCK1 <= 
m; i += M_BLOCK1) { int j = 0; for (; j + 4 <= n; j += 4) @@ -3112,8 +2792,8 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons const float* k2 = K + (j + 2) * D; const float* k3 = K + (j + 3) * D; - __m128 acc[2][4]; - for (int mi = 0; mi < 2; mi++) + __m128 acc[M_BLOCK1][4]; + for (int mi = 0; mi < M_BLOCK1; mi++) { acc[mi][0] = _mm_setzero_ps(); acc[mi][1] = _mm_setzero_ps(); @@ -3128,7 +2808,7 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons __m128 kv2 = _mm_loadu_ps(k2 + k); __m128 kv3 = _mm_loadu_ps(k3 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); @@ -3138,7 +2818,7 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { S[(i + mi) * n + j + 0] = _mm_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; @@ -3152,8 +2832,8 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; - __m128 acc[2][2]; - for (int mi = 0; mi < 2; mi++) + __m128 acc[M_BLOCK1][2]; + for (int mi = 0; mi < M_BLOCK1; mi++) { acc[mi][0] = _mm_setzero_ps(); acc[mi][1] = _mm_setzero_ps(); @@ -3164,7 +2844,7 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons __m128 kv0 = _mm_loadu_ps(k0 + k); __m128 kv1 = _mm_loadu_ps(k1 + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); acc[mi][0] = _mm_comp_fmadd_ps(qvec, kv0, acc[mi][0]); @@ -3172,7 +2852,7 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { S[(i + mi) * n + j + 0] = 
_mm_reduce_add_ps(acc[mi][0]) * scale; S[(i + mi) * n + j + 1] = _mm_reduce_add_ps(acc[mi][1]) * scale; @@ -3183,21 +2863,21 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons { const float* kptr = K + j * D; - __m128 acc[2]; - for (int mi = 0; mi < 2; mi++) + __m128 acc[M_BLOCK1]; + for (int mi = 0; mi < M_BLOCK1; mi++) acc[mi] = _mm_setzero_ps(); for (int k = 0; k < D; k += 4) { __m128 kvec = _mm_loadu_ps(kptr + k); - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) { __m128 qvec = _mm_loadu_ps(Q + (i + mi) * D + k); acc[mi] = _mm_comp_fmadd_ps(qvec, kvec, acc[mi]); } } - for (int mi = 0; mi < 2; mi++) + for (int mi = 0; mi < M_BLOCK1; mi++) S[(i + mi) * n + j] = _mm_reduce_add_ps(acc[mi]) * scale; } } @@ -3266,11 +2946,19 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons } +template<> +void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<1024, 2, 4>(S, Q, K, m, n, scale); +} + + template<> void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_sse2<2048>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_sse2<2048, 2, 0>(S, Q, K, m, n, scale); } @@ -3278,10 +2966,12 @@ template<> void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_sse2<4096>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_sse2<4096, 2, 0>(S, Q, K, m, n, scale); } + + template static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) { From 1f1b8a71ab113b83304e1952aa1f06cfb2e2f78c Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Tue, 28 Apr 2026 12:58:53 +0800 Subject: [PATCH 18/53] simd vec --- src/layer/x86/sdpa_x86.cpp | 87 +++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 48 
deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 662854701892..a9c61874e031 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -102,65 +102,59 @@ static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, static inline void vec_scale(float* x, float s, int n) { #if __AVX512F__ - __m512 vscale = _mm512_set1_ps(s); + __m512 vscale512 = _mm512_set1_ps(s); int i = 0; for (; i + 15 < n; i += 16) - _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale)); + _mm512_storeu_ps(x + i, _mm512_mul_ps(_mm512_loadu_ps(x + i), vscale512)); if (i < n) { __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); - _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale)); + _mm512_mask_storeu_ps(x + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, x + i), vscale512)); } -#elif __AVX__ - __m256 vscale = _mm256_set1_ps(s); +#else int i = 0; +#if __SSE2__ +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(s); for (; i + 7 < n; i += 8) - _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale)); - for (; i < n; i++) - x[i] *= s; -#elif __SSE2__ - __m128 vscale = _mm_set1_ps(s); - int i = 0; + _mm256_storeu_ps(x + i, _mm256_mul_ps(_mm256_loadu_ps(x + i), vscale256)); +#endif // __AVX__ + __m128 vscale128 = _mm_set1_ps(s); for (; i + 3 < n; i += 4) - _mm_storeu_ps(x + i, _mm_mul_ps(_mm_loadu_ps(x + i), vscale)); + _mm_storeu_ps(x + i, _mm_mul_ps(_mm_loadu_ps(x + i), vscale128)); +#endif // __SSE2__ for (; i < n; i++) x[i] *= s; -#else - for (int i = 0; i < n; i++) - x[i] *= s; -#endif +#endif // __AVX512F__ } static inline void vec_zero(float* x, int n) { #if __AVX512F__ - __m512 zero = _mm512_setzero_ps(); + __m512 zero512 = _mm512_setzero_ps(); int i = 0; for (; i + 15 < n; i += 16) - _mm512_storeu_ps(x + i, zero); + _mm512_storeu_ps(x + i, zero512); if (i < n) { __mmask16 mask = (__mmask16)((1u << (n - i)) - 1); - _mm512_mask_storeu_ps(x + i, mask, 
zero); + _mm512_mask_storeu_ps(x + i, mask, zero512); } -#elif __AVX__ - __m256 zero = _mm256_setzero_ps(); +#else int i = 0; +#if __SSE2__ +#if __AVX__ + __m256 zero256 = _mm256_setzero_ps(); for (; i + 7 < n; i += 8) - _mm256_storeu_ps(x + i, zero); - for (; i < n; i++) - x[i] = 0.f; -#elif __SSE2__ - __m128 zero = _mm_setzero_ps(); - int i = 0; + _mm256_storeu_ps(x + i, zero256); +#endif // __AVX__ + __m128 zero128 = _mm_setzero_ps(); for (; i + 3 < n; i += 4) - _mm_storeu_ps(x + i, zero); + _mm_storeu_ps(x + i, zero128); +#endif // __SSE2__ for (; i < n; i++) x[i] = 0.f; -#else - for (int i = 0; i < n; i++) - x[i] = 0.f; -#endif +#endif // __AVX512F__ } static inline void softmax_tile(float* P, const float* S, @@ -312,47 +306,44 @@ static inline void decode_pv_gemv(float* out, const float* s, const float* V, in for (int j = 0; j < block_n; j++) { #if __AVX512F__ - __m512 pvec = _mm512_set1_ps(s[j]); int k = 0; + __m512 pvec512 = _mm512_set1_ps(s[j]); for (; k + 15 < out_d; k += 16) { __m512 oval = _mm512_loadu_ps(out + k); __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); - _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vval, oval)); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec512, vval, oval)); } if (k < out_d) { __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); - _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec, vval, oval)); + _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec512, vval, oval)); } -#elif __AVX__ - __m256 pvec = _mm256_set1_ps(s[j]); +#else int k = 0; +#if __SSE2__ +#if __AVX__ + __m256 pvec256 = _mm256_set1_ps(s[j]); for (; k + 7 < out_d; k += 8) { __m256 oval = _mm256_loadu_ps(out + k); __m256 vval = _mm256_loadu_ps(V + (n_start + j) * out_d + k); - _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec, vval, oval)); + _mm256_storeu_ps(out + k, 
_mm256_comp_fmadd_ps(pvec256, vval, oval)); } - for (; k < out_d; k++) - out[k] += s[j] * V[(n_start + j) * out_d + k]; -#elif __SSE2__ - __m128 pvec = _mm_set1_ps(s[j]); - int k = 0; +#endif // __AVX__ + __m128 pvec128 = _mm_set1_ps(s[j]); for (; k + 3 < out_d; k += 4) { __m128 oval = _mm_loadu_ps(out + k); __m128 vval = _mm_loadu_ps(V + (n_start + j) * out_d + k); - _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec, vval, oval)); + _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec128, vval, oval)); } +#endif // __SSE2__ for (; k < out_d; k++) out[k] += s[j] * V[(n_start + j) * out_d + k]; -#else - for (int k = 0; k < out_d; k++) - out[k] += s[j] * V[(n_start + j) * out_d + k]; -#endif +#endif // __AVX512F__ } } From c7c7d1da4988ab69945cc1cc0da15935170e0af1 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Tue, 28 Apr 2026 19:47:18 +0800 Subject: [PATCH 19/53] remove unless specialization --- src/layer/x86/sdpa_x86.cpp | 259 +++++-------------------------------- 1 file changed, 29 insertions(+), 230 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index a9c61874e031..e79d1652e127 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1036,231 +1036,6 @@ static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const fl } } -// Explicit specialization for D=128: 6x4 kernel to improve K-tile reuse -template<> -void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, - int m, int n, float scale) -{ - int i = 0; - for (; i + 6 <= m; i += 6) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - const float* k2 = K + (j + 2) * 128; - const float* k3 = K + (j + 3) * 128; - - __m512 acc[6][4]; - for (int mi = 0; mi < 6; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - acc[mi][2] = _mm512_setzero_ps(); - acc[mi][3] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 
128; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - __m512 kv2 = _mm512_loadu_ps(k2 + k); - __m512 kv3 = _mm512_loadu_ps(k3 + k); - - for (int mi = 0; mi < 6; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); - } - } - - for (int mi = 0; mi < 6; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; - } - } - - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - - __m512 acc[6][2]; - for (int mi = 0; mi < 6; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 128; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - - for (int mi = 0; mi < 6; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 6; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - const float* kptr = K + j * 128; - - __m512 acc[6]; - for (int mi = 0; mi < 6; mi++) - acc[mi] = _mm512_setzero_ps(); - - for (int k = 0; k < 128; k += 16) - { - __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 6; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi] = _mm512_fmadd_ps(qvec, kvec, 
acc[mi]); - } - } - - for (int mi = 0; mi < 6; mi++) - S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i + 4 <= m; i += 4) - { - int j = 0; - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - - __m512 acc[4][2]; - for (int mi = 0; mi < 4; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - } - - for (int k = 0; k < 128; k += 16) - { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - - for (int mi = 0; mi < 4; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - } - } - - for (int mi = 0; mi < 4; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - } - } - - for (; j < n; j++) - { - __m512 acc[4]; - for (int mi = 0; mi < 4; mi++) - acc[mi] = _mm512_setzero_ps(); - - const float* kptr = K + j * 128; - for (int k = 0; k < 128; k += 16) - { - __m512 kvec = _mm512_loadu_ps(kptr + k); - for (int mi = 0; mi < 4; mi++) - { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * 128 + k); - acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); - } - } - - for (int mi = 0; mi < 4; mi++) - S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; - } - } - - for (; i < m; i++) - { - int j = 0; - for (; j + 4 <= n; j += 4) - { - const float* qptr = Q + i * 128; - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - const float* k2 = K + (j + 2) * 128; - const float* k3 = K + (j + 3) * 128; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - __m512 acc2 = _mm512_setzero_ps(); - __m512 acc3 = _mm512_setzero_ps(); - - for (int k = 0; k < 128; k += 16) - { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + 
k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); - } - - S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; - S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; - } - - for (; j + 2 <= n; j += 2) - { - const float* qptr = Q + i * 128; - const float* k0 = K + (j + 0) * 128; - const float* k1 = K + (j + 1) * 128; - - __m512 acc0 = _mm512_setzero_ps(); - __m512 acc1 = _mm512_setzero_ps(); - - for (int k = 0; k < 128; k += 16) - { - __m512 qvec = _mm512_loadu_ps(qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k1 + k), acc1); - } - - S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; - S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; - } - - for (; j < n; j++) - { - const float* qptr = Q + i * 128; - const float* kptr = K + j * 128; - __m512 vacc = _mm512_setzero_ps(); - for (int k = 0; k < 128; k += 16) - vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; - } - } -} template @@ -1535,6 +1310,18 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co } +template<> +void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<128, 4, 6>(S, Q, K, m, n, scale); +} +template<> +void qk_gemm_specialized_avx512<512>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<512, 4, 8>(S, Q, K, m, n, scale); +} template<> void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, int m, int n, float scale) @@ -1546,7 +1333,7 @@ 
template<> void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx512<1024, 2, 4>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx512<1024, 4, 4>(S, Q, K, m, n, scale); } @@ -1555,7 +1342,7 @@ template<> void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx512<4096, 2, 0>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); } @@ -2352,11 +2139,17 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const } +template<> +void qk_gemm_specialized_avx<512>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx<512, 2, 4>(S, Q, K, m, n, scale); +} template<> void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx<1024, 2, 4>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx<1024, 4, 2>(S, Q, K, m, n, scale); } @@ -2372,7 +2165,7 @@ template<> void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx<4096, 2, 0>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx<4096, 2, 2>(S, Q, K, m, n, scale); } @@ -2937,6 +2730,12 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons } +template<> +void qk_gemm_specialized_sse2<512>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_sse2<512, 2, 4>(S, Q, K, m, n, scale); +} template<> void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, int m, int n, float scale) @@ -2957,7 +2756,7 @@ template<> void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_sse2<4096, 2, 0>(S, Q, K, m, n, scale); + 
qk_gemm_specialized_tiled_sse2<4096, 2, 2>(S, Q, K, m, n, scale); } From 274f6d5cf0af996381f5ff4b310d080b5f6ef632 Mon Sep 17 00:00:00 2001 From: futz12 <56149058+futz12@users.noreply.github.com> Date: Tue, 28 Apr 2026 12:28:32 +0000 Subject: [PATCH 20/53] apply code-format changes --- src/layer/x86/sdpa_x86.cpp | 226 ++++++++++++++++--------------------- 1 file changed, 98 insertions(+), 128 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index e79d1652e127..7703bc230197 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -81,9 +81,8 @@ static void dynamic_quantize_blockwise(const float* src, signed char* dst, float } #endif // NCNN_INT8 - static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { for (int i = 0; i < m; i++) { @@ -158,7 +157,7 @@ static inline void vec_zero(float* x, int n) } static inline void softmax_tile(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) + float* m_vec, float* l_vec, float* scale_out, int m, int n) { for (int i = 0; i < m; i++) { @@ -285,7 +284,7 @@ static inline void softmax_tile(float* P, const float* S, } static inline void pv_gemm_scalar(float* O, const float* P, const float* V, - int m, int n, int d) + int m, int n, int d) { for (int i = 0; i < m; i++) { @@ -468,8 +467,8 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n } static void sdpa_decode(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) { const int BLOCK_N = 128; #if __AVX512F__ @@ -503,7 +502,7 @@ static void sdpa_decode(float* out, const float* q, { __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); _mm512_mask_storeu_ps(s + j, mask_n, - 
_mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); } #elif __AVX__ int j = 0; @@ -638,7 +637,7 @@ static void sdpa_decode(float* out, const float* q, #if __AVX512F__ static void qk_gemm_avx512(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { int i = 0; for (; i + 8 <= m; i += 8) @@ -872,7 +871,6 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, } } - template static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const float* K, int m, int n, float scale) @@ -1036,8 +1034,6 @@ static inline void qk_gemm_specialized_avx512(float* S, const float* Q, const fl } } - - template static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, const float* K, int m, int n, float scale) @@ -1046,105 +1042,105 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co if (M_BLOCK2 > 0) { for (; i + M_BLOCK2 <= m; i += M_BLOCK2) - { - int j = 0; - for (; j + 4 <= n; j += 4) { - const float* k0 = K + (j + 0) * D; - const float* k1 = K + (j + 1) * D; - const float* k2 = K + (j + 2) * D; - const float* k3 = K + (j + 3) * D; - - __m512 acc[M_BLOCK2][4]; - for (int mi = 0; mi < M_BLOCK2; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); - acc[mi][2] = _mm512_setzero_ps(); - acc[mi][3] = _mm512_setzero_ps(); - } - - for (int k = 0; k < D; k += 16) + int j = 0; + for (; j + 4 <= n; j += 4) { - __m512 kv0 = _mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); - __m512 kv2 = _mm512_loadu_ps(k2 + k); - __m512 kv3 = _mm512_loadu_ps(k3 + k); + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + const float* k2 = K + (j + 2) * D; + const float* k3 = K + (j + 3) * D; + __m512 acc[M_BLOCK2][4]; for (int mi = 0; mi < M_BLOCK2; 
mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); - acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); - acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + acc[mi][2] = _mm512_setzero_ps(); + acc[mi][3] = _mm512_setzero_ps(); } - } - for (int mi = 0; mi < M_BLOCK2; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; - S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; - } - } + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); + __m512 kv2 = _mm512_loadu_ps(k2 + k); + __m512 kv3 = _mm512_loadu_ps(k3 + k); - for (; j + 2 <= n; j += 2) - { - const float* k0 = K + (j + 0) * D; - const float* k1 = K + (j + 1) * D; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(qvec, kv2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(qvec, kv3, acc[mi][3]); + } + } - __m512 acc[M_BLOCK2][2]; - for (int mi = 0; mi < M_BLOCK2; mi++) - { - acc[mi][0] = _mm512_setzero_ps(); - acc[mi][1] = _mm512_setzero_ps(); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + S[(i + mi) * n + j + 2] = _mm512_comp_reduce_add_ps(acc[mi][2]) * scale; + S[(i + mi) * n + j + 3] = _mm512_comp_reduce_add_ps(acc[mi][3]) * scale; + } } - for (int k = 0; k < D; k += 16) + for (; j + 2 <= n; j += 2) { - __m512 kv0 = 
_mm512_loadu_ps(k0 + k); - __m512 kv1 = _mm512_loadu_ps(k1 + k); + const float* k0 = K + (j + 0) * D; + const float* k1 = K + (j + 1) * D; + __m512 acc[M_BLOCK2][2]; for (int mi = 0; mi < M_BLOCK2; mi++) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); - acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); - acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); } - } - for (int mi = 0; mi < M_BLOCK2; mi++) - { - S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; - S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; - } - } + for (int k = 0; k < D; k += 16) + { + __m512 kv0 = _mm512_loadu_ps(k0 + k); + __m512 kv1 = _mm512_loadu_ps(k1 + k); - for (; j < n; j++) - { - const float* kptr = K + j * D; + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } - __m512 acc[M_BLOCK2]; - for (int mi = 0; mi < M_BLOCK2; mi++) - acc[mi] = _mm512_setzero_ps(); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } - for (int k = 0; k < D; k += 16) + for (; j < n; j++) { - __m512 kvec = _mm512_loadu_ps(kptr + k); + const float* kptr = K + j * D; + + __m512 acc[M_BLOCK2]; for (int mi = 0; mi < M_BLOCK2; mi++) + acc[mi] = _mm512_setzero_ps(); + + for (int k = 0; k < D; k += 16) { - __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); - acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + __m512 kvec = _mm512_loadu_ps(kptr + k); + for (int mi = 0; mi < M_BLOCK2; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * D + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } } - } - for (int mi = 0; mi < M_BLOCK2; mi++) - S[(i + mi) * n + j] = 
_mm512_comp_reduce_add_ps(acc[mi]) * scale; + for (int mi = 0; mi < M_BLOCK2; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } } } - } for (; i + M_BLOCK1 <= m; i += M_BLOCK1) { @@ -1309,43 +1305,38 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co } } - template<> void qk_gemm_specialized_avx512<128>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx512<128, 4, 6>(S, Q, K, m, n, scale); } template<> void qk_gemm_specialized_avx512<512>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx512<512, 4, 8>(S, Q, K, m, n, scale); } template<> void qk_gemm_specialized_avx512<2048>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { - qk_gemm_specialized_tiled_avx512<2048, 2, 0>(S, Q, K, m, n, scale); } template<> void qk_gemm_specialized_avx512<1024>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx512<1024, 4, 4>(S, Q, K, m, n, scale); } - - template<> void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); } - template static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) { @@ -1476,7 +1467,6 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int } } - template static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) { @@ -1633,13 +1623,12 @@ static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int } } - #endif // __AVX512F__ #if __AVX__ static void qk_gemm_avx(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) + int m, int n, 
int d, float scale) { int i = 0; for (; i + 6 <= m; i += 6) @@ -1768,7 +1757,6 @@ static void qk_gemm_avx(float* S, const float* Q, const float* K, } } - template static inline void qk_gemm_specialized_avx(float* S, const float* Q, const float* K, int m, int n, float scale) @@ -2138,39 +2126,33 @@ static inline void qk_gemm_specialized_tiled_avx(float* S, const float* Q, const } } - template<> void qk_gemm_specialized_avx<512>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx<512, 2, 4>(S, Q, K, m, n, scale); } template<> void qk_gemm_specialized_avx<1024>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx<1024, 4, 2>(S, Q, K, m, n, scale); } - template<> void qk_gemm_specialized_avx<2048>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx<2048, 2, 0>(S, Q, K, m, n, scale); } - template<> void qk_gemm_specialized_avx<4096>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx<4096, 2, 2>(S, Q, K, m, n, scale); } - - - template static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n, int d) { @@ -2283,7 +2265,6 @@ static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, } } - template static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, int n) { @@ -2338,13 +2319,12 @@ static inline void pv_gemm_avx(float* O, const float* P, const float* V, int m, } } - #endif // __AVX__ #if __SSE2__ static void qk_gemm_sse2(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { int i = 0; for (; i + 4 <= m; i += 4) @@ -2398,7 +2378,6 @@ static void qk_gemm_sse2(float* S, const float* Q, const float* K, } } - template static 
inline void qk_gemm_specialized_sse2(float* S, const float* Q, const float* K, int m, int n, float scale) @@ -2729,39 +2708,33 @@ static inline void qk_gemm_specialized_tiled_sse2(float* S, const float* Q, cons } } - template<> void qk_gemm_specialized_sse2<512>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_sse2<512, 2, 4>(S, Q, K, m, n, scale); } template<> void qk_gemm_specialized_sse2<1024>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_sse2<1024, 2, 4>(S, Q, K, m, n, scale); } - template<> void qk_gemm_specialized_sse2<2048>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_sse2<2048, 2, 0>(S, Q, K, m, n, scale); } - template<> void qk_gemm_specialized_sse2<4096>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_sse2<4096, 2, 2>(S, Q, K, m, n, scale); } - - - template static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n, int d) { @@ -2874,7 +2847,6 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, } } - template static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, int n) { @@ -2929,12 +2901,10 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, } } - #endif // __SSE2__ - static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { if (d == 128) { @@ -3092,7 +3062,7 @@ static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, } static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, - int m, int n, int d) + int m, int n, int d) { if (d == 128) { @@ -3214,8 +3184,8 @@ static inline void 
softmax_tile_dispatch(float* P, const float* S, } static inline void sdpa_decode_dispatch(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) + const float* K, const float* V, const float* mask, + int n, int d, int out_d, float scale) { sdpa_decode(out, q, K, V, mask, n, d, out_d, scale); } From 1473d633b8281053409b27c71bb8a342dc925db8 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 29 Apr 2026 11:14:16 +0800 Subject: [PATCH 21/53] split kv --- src/layer/x86/sdpa_x86.cpp | 287 +++++++++++++++++++++++++++++++++++-- 1 file changed, 277 insertions(+), 10 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 7703bc230197..e1cc4e37e2ff 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -634,6 +634,212 @@ static void sdpa_decode(float* out, const float* q, vec_scale(out, inv_l, out_d); } +static void sdpa_decode_chunk( + float* out, float* m_out, float* l_out, + const float* q, const float* K, const float* V, const float* mask, + int n_start, int n_end, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(out, out_d); + + float m = -FLT_MAX; + float l = 0.f; + + for (int n = n_start; n < n_end; n += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n_end - n); + + decode_qk_dot(s, q, K, n, block_n, d, scale); + + if (mask) + { +#if __AVX512F__ + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), 
_mm512_maskz_loadu_ps(mask_n, mask + n + j))); + } +#elif __AVX__ + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n + j))); + for (; j < block_n; j++) + s[j] += mask[n + j]; +#elif __SSE2__ + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n + j))); + for (; j < block_n; j++) + s[j] += mask[n + j]; +#else + for (int j = 0; j < block_n; j++) + s[j] += mask[n + j]; +#endif + } + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + } + float tile_m = _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float tile_m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float tile_m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#else + float tile_m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#endif + + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(out, scale_factor, out_d); + } + +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = 
_mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + l += _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#else + float l_add = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#endif + + decode_pv_gemv(out, s, V, n, block_n, out_d); + + m = new_m; + } + + *m_out = m; + *l_out = l; +} + +static void sdpa_decode_reduce( + float* out, int out_d, + const float* partials, int num_chunks, int partial_stride) +{ + float M_final = -FLT_MAX; + float S_final = 0.f; + vec_zero(out, out_d); + + for (int c = 0; c < num_chunks; c++) + { + const float* p = partials + c * partial_stride; + float M_chunk = p[0]; + float S_chunk = p[1]; + if (S_chunk == 0.f) continue; + + const float* VKQ_chunk = p + 2; + + float M_new = std::max(M_final, M_chunk); + float scale_final = expf(M_final - M_new); + float scale_chunk = expf(M_chunk - M_new); + + for (int k = 0; k < out_d; 
k++) + { + out[k] = out[k] * scale_final + VKQ_chunk[k] * scale_chunk; + } + + S_final = S_final * scale_final + S_chunk * scale_chunk; + M_final = M_new; + } + + if (S_final != 0.f) + { + float inv_s = 1.f / S_final; + vec_scale(out, inv_s, out_d); + } +} + #if __AVX512F__ static void qk_gemm_avx512(float* S, const float* Q, const float* K, @@ -3477,21 +3683,31 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to // FP32 optimized path using tiled GEMM + online softmax if (src_seqlen == 1) { - // Decode path: fused GEMV kernel for single-query attention - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) + const int BLOCK_N = 128; + const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; + + if (use_split_kv) { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; - for (int hq = 0; hq < num_heads_per_group; hq++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) { - int q = g * num_heads_per_group + hq; + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; + + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? 
dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); const Mat query_head = query.channel(q); - Mat top_blob_head = top_blob.channel(q); const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); const float* Kptr = key_head; const float* Vptr = value_head; @@ -3511,7 +3727,58 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to mask_ptr = mask_head.row(0); } - sdpa_decode_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + float* p = partials.channel(q).row(chunk); + sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, + n_start, n_end, embed_dim, out_embed_dim, _scale); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + float* outptr = top_blob_head.row(0); + sdpa_decode_reduce(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } + } + else + { + // Decode path: fused GEMV kernel for single-query attention + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } } } From cdd69dd69fe42b47c9d8198f9550403d33c4a18c Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 29 Apr 2026 15:28:15 +0800 Subject: [PATCH 22/53] prefetch opt for decode --- src/layer/x86/sdpa_x86.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index e1cc4e37e2ff..af4870ce9a15 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -305,6 +305,9 @@ static inline void decode_pv_gemv(float* out, const float* s, const float* V, in for (int j = 0; j < block_n; j++) { #if __AVX512F__ + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + int k = 0; __m512 pvec512 = _mm512_set1_ps(s[j]); for (; k + 15 < out_d; k += 16) @@ -321,6 +324,9 @@ static inline void decode_pv_gemv(float* out, const float* s, const float* V, in _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec512, vval, oval)); } #else + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + int k = 0; #if __SSE2__ #if __AVX__ @@ -357,6 +363,14 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n const float* k2 = K + (n_start + j + 2) * d; const float* k3 = K + (n_start + j + 3) * d; + if (j + 7 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 5) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 6) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 7) * d), _MM_HINT_T1); + } + __m512 acc0 = _mm512_setzero_ps(); __m512 acc1 = _mm512_setzero_ps(); __m512 acc2 = _mm512_setzero_ps(); @@ -389,6 +403,9 @@ static inline 
void decode_qk_dot(float* s, const float* q, const float* K, int n for (; j < block_n; j++) { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + __m512 acc = _mm512_setzero_ps(); int k = 0; for (; k + 15 < d; k += 16) @@ -407,6 +424,12 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n const float* k0 = K + (n_start + j + 0) * d; const float* k1 = K + (n_start + j + 1) * d; + if (j + 5 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 3) * d), _MM_HINT_T1); + } + __m256 acc0 = _mm256_setzero_ps(); __m256 acc1 = _mm256_setzero_ps(); @@ -433,6 +456,9 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n for (; j < block_n; j++) { + if (j + 2 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + const float* kptr = K + (n_start + j) * d; __m256 acc = _mm256_setzero_ps(); int k = 0; @@ -446,6 +472,9 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n #elif __SSE2__ for (int j = 0; j < block_n; j++) { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + __m128 acc = _mm_setzero_ps(); int k = 0; for (; k + 3 < d; k += 4) @@ -458,6 +487,9 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n #else for (int j = 0; j < block_n; j++) { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + float sum = 0.f; for (int k = 0; k < d; k++) sum += q[k] * K[(n_start + j) * d + k]; From 5a0a8792d43d5b69ca5534bc66ba88a70197ee1e Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 29 Apr 2026 18:30:58 +0800 Subject: [PATCH 23/53] opt for gemv --- src/layer/x86/sdpa_x86.cpp | 142 +++++++++++++++++++++++++++++++++++-- 1 file changed, 136 insertions(+), 6 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp 
b/src/layer/x86/sdpa_x86.cpp index af4870ce9a15..140a33df7d52 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -302,19 +302,73 @@ static inline void pv_gemm_scalar(float* O, const float* P, const float* V, static inline void decode_pv_gemv(float* out, const float* s, const float* V, int n_start, int block_n, int out_d) { - for (int j = 0; j < block_n; j++) - { #if __AVX512F__ - if (j + 4 < block_n) - _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + int j = 0; + for (; j + 1 < block_n; j += 2) + { + if (j + 6 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 6) * out_d), _MM_HINT_T1); + __m512 pvec0 = _mm512_set1_ps(s[j]); + __m512 pvec1 = _mm512_set1_ps(s[j + 1]); int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v00 = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + __m512 v01 = _mm512_loadu_ps(V + (n_start + j) * out_d + k + 16); + __m512 v10 = _mm512_loadu_ps(V + (n_start + j + 1) * out_d + k); + __m512 v11 = _mm512_loadu_ps(V + (n_start + j + 1) * out_d + k + 16); + oval0 = _mm512_fmadd_ps(pvec0, v00, oval0); + oval1 = _mm512_fmadd_ps(pvec0, v01, oval1); + oval0 = _mm512_fmadd_ps(pvec1, v10, oval0); + oval1 = _mm512_fmadd_ps(pvec1, v11, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 v0 = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + __m512 v1 = _mm512_loadu_ps(V + (n_start + j + 1) * out_d + k); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_storeu_ps(out + k, oval); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 v0 = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); + __m512 v1 = _mm512_maskz_loadu_ps(mask_d, V 
+ (n_start + j + 1) * out_d + k); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_mask_storeu_ps(out + k, mask_d, oval); + } + } + for (; j < block_n; j++) + { __m512 pvec512 = _mm512_set1_ps(s[j]); - for (; k + 15 < out_d; k += 16) + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v0 = _mm512_loadu_ps(V + (n_start + j) * out_d + k); + __m512 v1 = _mm512_loadu_ps(V + (n_start + j) * out_d + k + 16); + oval0 = _mm512_fmadd_ps(pvec512, v0, oval0); + oval1 = _mm512_fmadd_ps(pvec512, v1, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) { __m512 oval = _mm512_loadu_ps(out + k); __m512 vval = _mm512_loadu_ps(V + (n_start + j) * out_d + k); _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec512, vval, oval)); + k += 16; } if (k < out_d) { @@ -323,7 +377,11 @@ static inline void decode_pv_gemv(float* out, const float* s, const float* V, in __m512 vval = _mm512_maskz_loadu_ps(mask_d, V + (n_start + j) * out_d + k); _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec512, vval, oval)); } + } #else + int j = 0; + for (; j < block_n; j++) + { if (j + 4 < block_n) _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); @@ -348,14 +406,86 @@ static inline void decode_pv_gemv(float* out, const float* s, const float* V, in #endif // __SSE2__ for (; k < out_d; k++) out[k] += s[j] * V[(n_start + j) * out_d + k]; -#endif // __AVX512F__ } +#endif // __AVX512F__ } static inline void decode_qk_dot(float* s, const float* q, const float* K, int n_start, int block_n, int d, float scale) { #if __AVX512F__ int j = 0; + if (d >= 256) + { + for (; j + 7 < block_n; j += 8) + { + const float* k0 = K + (n_start + j + 0) * d; + const float* k1 = K + (n_start + j + 1) * d; + const float* k2 = K + (n_start + j + 2) * d; + const float* k3 = K + (n_start + j + 3) * d; + const 
float* k4 = K + (n_start + j + 4) * d; + const float* k5 = K + (n_start + j + 5) * d; + const float* k6 = K + (n_start + j + 6) * d; + const float* k7 = K + (n_start + j + 7) * d; + + if (j + 15 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 8) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 9) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 10) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 11) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 12) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 13) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 14) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 15) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + __m512 acc4 = _mm512_setzero_ps(); + __m512 acc5 = _mm512_setzero_ps(); + __m512 acc6 = _mm512_setzero_ps(); + __m512 acc7 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k3 + k), acc3); + acc4 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k4 + k), acc4); + acc5 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k5 + k), acc5); + acc6 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k6 + k), acc6); + acc7 = _mm512_fmadd_ps(qv, _mm512_loadu_ps(k7 + k), acc7); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + acc0 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, 
k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k3 + k), acc3); + acc4 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k4 + k), acc4); + acc5 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k5 + k), acc5); + acc6 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k6 + k), acc6); + acc7 = _mm512_fmadd_ps(qv, _mm512_maskz_loadu_ps(mask_d, k7 + k), acc7); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + s[j + 4] = _mm512_comp_reduce_add_ps(acc4) * scale; + s[j + 5] = _mm512_comp_reduce_add_ps(acc5) * scale; + s[j + 6] = _mm512_comp_reduce_add_ps(acc6) * scale; + s[j + 7] = _mm512_comp_reduce_add_ps(acc7) * scale; + } + } + for (; j + 3 < block_n; j += 4) { const float* k0 = K + (n_start + j + 0) * d; From ce8b73f35e6c7d1b48253a34e08861e895a9e9cb Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 29 Apr 2026 23:50:56 +0800 Subject: [PATCH 24/53] support perf int8 --- tests/perf/perf_sdpa_decode.cpp | 10 +++++++++ tests/perf/perf_sdpa_prefill.cpp | 10 +++++++++ tests/perf/perfutil.cpp | 35 ++++++++++++++++++++++++++++++++ tests/perf/perfutil.h | 6 ++++++ 4 files changed, 61 insertions(+) diff --git a/tests/perf/perf_sdpa_decode.cpp b/tests/perf/perf_sdpa_decode.cpp index 7910c7ef0179..cbde525a48c7 100644 --- a/tests/perf/perf_sdpa_decode.cpp +++ b/tests/perf/perf_sdpa_decode.cpp @@ -29,6 +29,16 @@ static void perf_sdpa_decode(int embed_dim, int num_heads, int num_groups, int p perf_layer("SDPA", pd, weights, inputs, 3, "embed=%d heads=%d groups=%d past=%d", embed_dim, num_heads, num_groups, past_seqlen); + + // int8 variant + ncnn::ParamDict pd_int8; + pd_int8.set(5, 0); // attn_mask = 0 + pd_int8.set(6, 0.f); // scale = 0 + pd_int8.set(7, 1); // kv_cache = 1 + pd_int8.set(18, 2); // int8_scale_term + perf_layer_int8("SDPA", pd_int8, weights, 
inputs, 3, + "embed=%d heads=%d groups=%d past=%d", + embed_dim, num_heads, num_groups, past_seqlen); } int main() diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp index 31212c248ab1..7717ab9942ac 100644 --- a/tests/perf/perf_sdpa_prefill.cpp +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -25,6 +25,16 @@ static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int perf_layer("SDPA", pd, weights, inputs, 1, "embed=%d heads=%d groups=%d seqlen=%d", embed_dim, num_heads, num_groups, src_seqlen); + + // int8 variant + ncnn::ParamDict pd_int8; + pd_int8.set(5, 0); // attn_mask = 0 + pd_int8.set(6, 0.f); // scale = 0 + pd_int8.set(7, 0); // kv_cache = 0 + pd_int8.set(18, 2); // int8_scale_term + perf_layer_int8("SDPA", pd_int8, weights, inputs, 1, + "embed=%d heads=%d groups=%d seqlen=%d", + embed_dim, num_heads, num_groups, src_seqlen); } int main() diff --git a/tests/perf/perfutil.cpp b/tests/perf/perfutil.cpp index 437db1251f2d..36331f937f3d 100644 --- a/tests/perf/perfutil.cpp +++ b/tests/perf/perfutil.cpp @@ -821,3 +821,38 @@ void perf_layer(const char* layer_type, const ncnn::ParamDict& pd, perf_layer_impl(layer_type, pd, weights, inputs, 1, tag); } + +void perf_layer_int8(const char* layer_type, const ncnn::ParamDict& pd, + const std::vector& weights, + const std::vector& inputs, int top_blob_count, + const char* param_fmt, ...) 
+{ + char tag[256]; + va_list args; + va_start(args, param_fmt); + build_tag(tag, sizeof(tag), layer_type, inputs, param_fmt, args); + va_end(args); + + ncnn::Option opt; + opt.lightmode = true; + opt.num_threads = 1; + opt.use_packing_layout = true; + opt.use_fp16_packed = false; + opt.use_fp16_storage = false; + opt.use_fp16_arithmetic = false; + opt.use_bf16_packed = false; + opt.use_bf16_storage = false; + opt.use_vulkan_compute = false; + opt.use_winograd_convolution = true; + opt.use_sgemm_convolution = true; + opt.use_int8_inference = true; + + PerfResult result; + int ret = perf_layer_cpu(layer_type, pd, weights, inputs, top_blob_count, opt, 0, result); + if (ret == 0) + { + char full_tag[512]; + snprintf(full_tag, sizeof(full_tag), "%s int8", tag); + print_perf_result(full_tag, result); + } +} diff --git a/tests/perf/perfutil.h b/tests/perf/perfutil.h index 06b31eeef508..e4ed236de19a 100644 --- a/tests/perf/perfutil.h +++ b/tests/perf/perfutil.h @@ -28,4 +28,10 @@ void perf_layer(const char* layer_type, const ncnn::ParamDict& pd, const std::vector& inputs, int top_blob_count, const char* param_fmt, ...); +// int8-only benchmark (does not test fp16/bf16 variants) +void perf_layer_int8(const char* layer_type, const ncnn::ParamDict& pd, + const std::vector& weights, + const std::vector& inputs, int top_blob_count, + const char* param_fmt, ...); + #endif // PERFUTIL_H From bd7ea9cf2033be2fc39916296f23e90632510517 Mon Sep 17 00:00:00 2001 From: futz12 <56149058+futz12@users.noreply.github.com> Date: Wed, 29 Apr 2026 15:53:16 +0000 Subject: [PATCH 25/53] apply code-format changes --- tests/perf/perf_sdpa_decode.cpp | 8 ++++---- tests/perf/perf_sdpa_prefill.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/perf/perf_sdpa_decode.cpp b/tests/perf/perf_sdpa_decode.cpp index cbde525a48c7..60f39932e421 100644 --- a/tests/perf/perf_sdpa_decode.cpp +++ b/tests/perf/perf_sdpa_decode.cpp @@ -32,10 +32,10 @@ static void 
perf_sdpa_decode(int embed_dim, int num_heads, int num_groups, int p // int8 variant ncnn::ParamDict pd_int8; - pd_int8.set(5, 0); // attn_mask = 0 - pd_int8.set(6, 0.f); // scale = 0 - pd_int8.set(7, 1); // kv_cache = 1 - pd_int8.set(18, 2); // int8_scale_term + pd_int8.set(5, 0); // attn_mask = 0 + pd_int8.set(6, 0.f); // scale = 0 + pd_int8.set(7, 1); // kv_cache = 1 + pd_int8.set(18, 2); // int8_scale_term perf_layer_int8("SDPA", pd_int8, weights, inputs, 3, "embed=%d heads=%d groups=%d past=%d", embed_dim, num_heads, num_groups, past_seqlen); diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp index 7717ab9942ac..1977fbb9c39f 100644 --- a/tests/perf/perf_sdpa_prefill.cpp +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -28,10 +28,10 @@ static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int // int8 variant ncnn::ParamDict pd_int8; - pd_int8.set(5, 0); // attn_mask = 0 - pd_int8.set(6, 0.f); // scale = 0 - pd_int8.set(7, 0); // kv_cache = 0 - pd_int8.set(18, 2); // int8_scale_term + pd_int8.set(5, 0); // attn_mask = 0 + pd_int8.set(6, 0.f); // scale = 0 + pd_int8.set(7, 0); // kv_cache = 0 + pd_int8.set(18, 2); // int8_scale_term perf_layer_int8("SDPA", pd_int8, weights, inputs, 1, "embed=%d heads=%d groups=%d seqlen=%d", embed_dim, num_heads, num_groups, src_seqlen); From 69d873bf138255e5005ea6b6b28965b16b1ccb28 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Fri, 1 May 2026 21:45:13 +0800 Subject: [PATCH 26/53] basic int8 --- src/layer/x86/sdpa_x86.cpp | 38 +- src/layer/x86/sdpa_x86_avx2.cpp | 40 + src/layer/x86/sdpa_x86_int8.h | 2502 +++++++++++++++++++++++++++++++ 3 files changed, 2546 insertions(+), 34 deletions(-) create mode 100644 src/layer/x86/sdpa_x86_avx2.cpp create mode 100644 src/layer/x86/sdpa_x86_int8.h diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 140a33df7d52..c13c7ffac909 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -49,37 
+49,7 @@ int SDPA_x86::destroy_pipeline(const Option& /*_opt*/) return 0; } -#if NCNN_INT8 -static inline signed char float2int8(float v) -{ - int int32 = static_cast(round(v)); - if (int32 > 127) return 127; - if (int32 < -127) return -127; - return (signed char)int32; -} - -static void dynamic_quantize_blockwise(const float* src, signed char* dst, float* scales, int width) -{ - const int block_size = 32; - int num_blocks = (width + block_size - 1) / block_size; - for (int b = 0; b < num_blocks; b++) - { - int start = b * block_size; - int end = start + block_size < width ? start + block_size : width; - float absmax = 0.f; - for (int i = start; i < end; i++) - { - absmax = std::max(absmax, (float)fabs(src[i])); - } - float scale = absmax == 0.f ? 1.f : 127.f / absmax; - scales[b] = scale; - for (int i = start; i < end; i++) - { - dst[i] = float2int8(src[i] * scale); - } - } -} -#endif // NCNN_INT8 +#include "sdpa_x86_int8.h" static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, int m, int n, int d, float scale) @@ -3665,7 +3635,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat key_scales_head = key_scales.channel(g); for (int j = 0; j < dst_seqlen; j++) { - dynamic_quantize_blockwise(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); + dynamic_quantize_blockwise_dispatch(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); } const Mat value_head = value.channel(g); @@ -3673,7 +3643,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat value_scales_head = value_scales.channel(g); for (int j = 0; j < dst_seqlen; j++) { - dynamic_quantize_blockwise(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); + dynamic_quantize_blockwise_dispatch(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); } } @@ -3721,7 +3691,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& 
to for (int i = 0; i < block_m; i++) { - dynamic_quantize_blockwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + dynamic_quantize_blockwise_dispatch(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); } for (int i = 0; i < block_m; i++) diff --git a/src/layer/x86/sdpa_x86_avx2.cpp b/src/layer/x86/sdpa_x86_avx2.cpp new file mode 100644 index 000000000000..dd1277df7fba --- /dev/null +++ b/src/layer/x86/sdpa_x86_avx2.cpp @@ -0,0 +1,40 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVX2__ +// force emit inline function symbols for non-avx2 runtime dispatch +static void __attribute__((used)) sdpa_x86_int8_avx2_dummy() +{ + // call through volatile function pointer to force instantiation + void (* volatile f1)(const float*, signed char*, float*, int) = dynamic_quantize_blockwise_avx2; + void (* volatile f2)(const float*, signed char*, float*, int) = dynamic_quantize_rowwise_avx2; + int (* volatile f3)(const signed char*, const signed char*, int) = qk_int8_dot_block_avx2; + void (* volatile f4)(float*, const signed char*, const signed char*, const float*, const float*, int, int, int, float) = decode_qk_dot_int8_avx2; + void (* volatile f5)(float*, const signed char*, const signed char*, float, const float*, int, int, float) = qk_int8_gemm_row_avx2; + void (* volatile f6)(float*, const signed char*, const signed char*, const float*, const float*, int, int, int, float) = qk_int8_gemm_tiled_avx2; + void (* volatile f7)(float*, const float*, const signed char*, const float*, int, int, int) = decode_pv_gemv_int8_avx2; + void (* volatile f8)(float*, const float*, const signed char*, const float*, int, int) 
= pv_float_int8_gemm_row_avx2; + void (* volatile f9)(float*, float, const signed char*, int) = pv_float_int8_fma_block_avx2; + void (* volatile f10)(float*, const float*, const signed char*, const float*, int, int, int) = pv_float_int8_gemm_tile_avx2; + (void)f1; (void)f2; (void)f3; (void)f4; (void)f5; + (void)f6; (void)f7; (void)f8; (void)f9; (void)f10; +} +#endif // __AVX2__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_int8.h b/src/layer/x86/sdpa_x86_int8.h new file mode 100644 index 000000000000..1ba9d19311f9 --- /dev/null +++ b/src/layer/x86/sdpa_x86_int8.h @@ -0,0 +1,2502 @@ +static inline signed char float2int8(float v) +{ + int int32 = static_cast(roundf(v)); + if (int32 > 127) return 127; + if (int32 < -127) return -127; + return (signed char)int32; +} + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +void dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width); +void dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width); +int qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len); +void decode_qk_dot_int8_avx2(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avx2(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avx2(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +void decode_pv_gemv_int8_avx2(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d); +void pv_float_int8_gemm_row_avx2(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d); +void pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed char* v, int len); +void 
pv_float_int8_gemm_tile_avx2(float* O, const float* P, const signed char* V, const float* vscales, int block_m, int block_n, int out_embed_dim); +#endif + +static void dynamic_quantize_blockwise_scalar(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + for (int i = start; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 1.f : 127.f / absmax; + scales[b] = scale; + for (int i = start; i < end; i++) + { + dst[i] = float2int8(src[i] * scale); + } + } +} + +#if __SSE2__ +static inline void dynamic_quantize_blockwise_sse2(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + __m128 sign_mask = _mm_castsi128_ps(_mm_set1_epi32(1 << 31)); + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + int i = start; + __m128 vmax = _mm_setzero_ps(); + for (; i + 3 < end; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + __m128 ax = _mm_andnot_ps(sign_mask, x); + vmax = _mm_max_ps(vmax, ax); + } + absmax = _mm_reduce_max_ps(vmax); + for (; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 
1.f : 127.f / absmax; + scales[b] = scale; + __m128 vscale = _mm_set1_ps(scale); + i = start; + for (; i + 15 < end; i += 16) + { + __m128 x0 = _mm_loadu_ps(src + i); + __m128 x1 = _mm_loadu_ps(src + i + 4); + __m128 x2 = _mm_loadu_ps(src + i + 8); + __m128 x3 = _mm_loadu_ps(src + i + 12); + x0 = _mm_mul_ps(x0, vscale); + x1 = _mm_mul_ps(x1, vscale); + x2 = _mm_mul_ps(x2, vscale); + x3 = _mm_mul_ps(x3, vscale); + __m128i v8 = float2int8_sse(x0, x1, x2, x3); + _mm_storeu_si128((__m128i*)(dst + i), v8); + } + for (; i + 3 < end; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + x = _mm_mul_ps(x, vscale); + int32_t v4 = float2int8_sse(x); + *(int32_t*)(dst + i) = v4; + } + for (; i < end; i++) + { + dst[i] = float2int8(src[i] * scale); + } + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void __attribute__((noinline)) dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + __m256 sign_mask = _mm256_set1_ps(-0.f); + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + int i = start; + __m256 vmax = _mm256_setzero_ps(); + for (; i + 7 < end; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + __m256 ax = _mm256_andnot_ps(sign_mask, x); + vmax = _mm256_max_ps(vmax, ax); + } + absmax = _mm256_reduce_max_ps(vmax); + for (; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 
1.f : 127.f / absmax; + scales[b] = scale; + __m256 vscale = _mm256_set1_ps(scale); + i = start; + for (; i + 31 < end; i += 32) + { + __m256 x0 = _mm256_loadu_ps(src + i); + __m256 x1 = _mm256_loadu_ps(src + i + 8); + __m256 x2 = _mm256_loadu_ps(src + i + 16); + __m256 x3 = _mm256_loadu_ps(src + i + 24); + __m256i y0 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x0, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i y1 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x1, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i y2 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x2, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i y3 = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x3, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i a0 = _mm256_castsi256_si128(y0); + __m128i a1 = _mm256_extractf128_si256(y0, 1); + __m128i a2 = _mm256_castsi256_si128(y1); + __m128i a3 = _mm256_extractf128_si256(y1, 1); + __m128i a4 = _mm256_castsi256_si128(y2); + __m128i a5 = _mm256_extractf128_si256(y2, 1); + __m128i a6 = _mm256_castsi256_si128(y3); + __m128i a7 = _mm256_extractf128_si256(y3, 1); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i b1 = _mm_packs_epi32(a2, a3); + __m128i b2 = _mm_packs_epi32(a4, a5); + __m128i b3 = _mm_packs_epi32(a6, a7); + __m128i c0 = _mm_packs_epi16(b0, b1); + __m128i c1 = _mm_packs_epi16(b2, b3); + _mm_storeu_si128((__m128i*)(dst + i), c0); + _mm_storeu_si128((__m128i*)(dst + i + 16), c1); + } + for (; i + 7 < end; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + __m256i y = _mm256_cvtps_epi32(_mm256_round_ps(_mm256_mul_ps(x, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i a0 = _mm256_castsi256_si128(y); + __m128i a1 = _mm256_extractf128_si256(y, 1); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i c0 = _mm_packs_epi16(b0, b0); + // store the low 8 packed int8 directly: _mm_cvtsi128_si64 is x86-64 only and an int64_t* cast is an unaligned, type-punned store + _mm_storel_epi64((__m128i*)(dst + i), c0); + } + for (; i < end; i++) + { + dst[i] =
float2int8(src[i] * scale); + } + } +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void dynamic_quantize_blockwise_avx512(const float* src, signed char* dst, float* scales, int width) +{ + const int block_size = 32; + int num_blocks = (width + block_size - 1) / block_size; + __m512 sign_mask = _mm512_set1_ps(-0.f); + for (int b = 0; b < num_blocks; b++) + { + int start = b * block_size; + int end = start + block_size < width ? start + block_size : width; + float absmax = 0.f; + int i = start; + __m512 vmax = _mm512_setzero_ps(); + for (; i + 15 < end; i += 16) + { + __m512 x = _mm512_loadu_ps(src + i); + __m512 ax = _mm512_andnot_ps(sign_mask, x); + vmax = _mm512_max_ps(vmax, ax); + } + absmax = _mm512_comp_reduce_max_ps(vmax); + for (; i < end; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float scale = absmax == 0.f ? 1.f : 127.f / absmax; + scales[b] = scale; + __m512 vscale = _mm512_set1_ps(scale); + i = start; + for (; i + 15 < end; i += 16) + { + __m512 x = _mm512_loadu_ps(src + i); + __m512i y = _mm512_cvtps_epi32(_mm512_roundscale_ps(_mm512_mul_ps(x, vscale), _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i a0 = _mm512_extracti32x4_epi32(y, 0); + __m128i a1 = _mm512_extracti32x4_epi32(y, 1); + __m128i a2 = _mm512_extracti32x4_epi32(y, 2); + __m128i a3 = _mm512_extracti32x4_epi32(y, 3); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i b1 = _mm_packs_epi32(a2, a3); + __m128i c0 = _mm_packs_epi16(b0, b1); + _mm_storeu_si128((__m128i*)(dst + i), c0); + } + for (; i < end; i++) + { + dst[i] = float2int8(src[i] * scale); + } + } +} +#endif // __AVX512F__ + +static void dynamic_quantize_blockwise_dispatch(const float* src, signed char* dst, float* scales, int width) +{ +#if __AVX512F__ + dynamic_quantize_blockwise_avx512(src, dst, scales, width); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + dynamic_quantize_blockwise_avx2(src, dst, scales, width); + return; + } +#endif 
+#if __AVX2__ + dynamic_quantize_blockwise_avx2(src, dst, scales, width); +#elif __SSE2__ + dynamic_quantize_blockwise_sse2(src, dst, scales, width); +#else + dynamic_quantize_blockwise_scalar(src, dst, scales, width); +#endif +#endif +} + +static inline void dynamic_quantize_rowwise_scalar(const float* src, signed char* dst, float* scale, int width) +{ + float absmax = 0.f; + for (int i = 0; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 1.f : 127.f / absmax; + *scale = s; + for (int i = 0; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} + +#if __SSE2__ +static inline void dynamic_quantize_rowwise_sse2(const float* src, signed char* dst, float* scale, int width) +{ + __m128 sign_mask = _mm_set1_ps(-0.f); + __m128 vmax = _mm_setzero_ps(); + int i = 0; + for (; i + 3 < width; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + vmax = _mm_max_ps(vmax, _mm_andnot_ps(sign_mask, x)); + } + float absmax = _mm_reduce_max_ps(vmax); + for (; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 
1.f : 127.f / absmax; + *scale = s; + __m128 vscale = _mm_set1_ps(s); + i = 0; + for (; i + 3 < width; i += 4) + { + __m128 x = _mm_loadu_ps(src + i); + __m128i yi = _mm_cvtps_epi32(_mm_mul_ps(x, vscale)); + (void)yi; + dst[i + 0] = float2int8(src[i + 0] * s); + dst[i + 1] = float2int8(src[i + 1] * s); + dst[i + 2] = float2int8(src[i + 2] * s); + dst[i + 3] = float2int8(src[i + 3] * s); + } + for (; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} +#endif + +#if __AVX2__ +inline void __attribute__((noinline)) dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width) +{ + __m256 sign_mask = _mm256_set1_ps(-0.f); + __m256 vmax = _mm256_setzero_ps(); + int i = 0; + for (; i + 7 < width; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + vmax = _mm256_max_ps(vmax, _mm256_andnot_ps(sign_mask, x)); + } + float absmax = _mm256_reduce_max_ps(vmax); + for (; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 
1.f : 127.f / absmax; + *scale = s; + __m256 vscale = _mm256_set1_ps(s); + i = 0; + for (; i + 7 < width; i += 8) + { + __m256 x = _mm256_loadu_ps(src + i); + __m256i y = _mm256_cvtps_epi32(_mm256_mul_ps(x, vscale)); + __m128i a0 = _mm256_extracti128_si256(y, 0); + __m128i a1 = _mm256_extracti128_si256(y, 1); + __m128i b0 = _mm_packs_epi32(a0, a1); + _mm_storel_epi64((__m128i*)(dst + i), _mm_packs_epi16(b0, b0)); + } + for (; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} +#endif + +#if __AVX512F__ +static inline void dynamic_quantize_rowwise_avx512(const float* src, signed char* dst, float* scale, int width) +{ + __m512 sign_mask = _mm512_set1_ps(-0.f); + __m512 vmax = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < width; i += 16) + { + __m512 x = _mm512_loadu_ps(src + i); + vmax = _mm512_max_ps(vmax, _mm512_andnot_ps(sign_mask, x)); + } + float absmax = _mm512_comp_reduce_max_ps(vmax); + for (; i < width; i++) + { + absmax = std::max(absmax, fabsf(src[i])); + } + float s = absmax == 0.f ? 
1.f : 127.f / absmax; + *scale = s; + __m512 vscale = _mm512_set1_ps(s); + i = 0; + for (; i + 15 < width; i += 16) + { + __m512 x = _mm512_loadu_ps(src + i); + __m512i y = _mm512_cvtps_epi32(_mm512_mul_ps(x, vscale)); + __m128i a0 = _mm512_extracti32x4_epi32(y, 0); + __m128i a1 = _mm512_extracti32x4_epi32(y, 1); + __m128i a2 = _mm512_extracti32x4_epi32(y, 2); + __m128i a3 = _mm512_extracti32x4_epi32(y, 3); + __m128i b0 = _mm_packs_epi32(a0, a1); + __m128i b1 = _mm_packs_epi32(a2, a3); + __m128i c0 = _mm_packs_epi16(b0, b1); + _mm_storeu_si128((__m128i*)(dst + i), c0); + } + for (; i < width; i++) + { + dst[i] = float2int8(src[i] * s); + } +} +#endif + +static void dynamic_quantize_rowwise_dispatch(const float* src, signed char* dst, float* scale, int width) +{ +#if __AVX512F__ + dynamic_quantize_rowwise_avx512(src, dst, scale, width); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + dynamic_quantize_rowwise_avx2(src, dst, scale, width); + return; + } +#endif +#if __AVX2__ + dynamic_quantize_rowwise_avx2(src, dst, scale, width); +#elif __SSE2__ + dynamic_quantize_rowwise_sse2(src, dst, scale, width); +#else + dynamic_quantize_rowwise_scalar(src, dst, scale, width); +#endif +#endif +} + +static inline void reciprocal_scales(float* scales, int num_blocks) +{ + int i = 0; +#if __AVX512F__ + for (; i + 15 < num_blocks; i += 16) + { + __m512 v = _mm512_loadu_ps(scales + i); + _mm512_storeu_ps(scales + i, _mm512_div_ps(_mm512_set1_ps(1.0f), v)); + } +#elif __AVX__ + for (; i + 7 < num_blocks; i += 8) + { + __m256 v = _mm256_loadu_ps(scales + i); + _mm256_storeu_ps(scales + i, _mm256_div_ps(_mm256_set1_ps(1.0f), v)); + } +#elif __SSE2__ + for (; i + 3 < num_blocks; i += 4) + { + __m128 v = _mm_loadu_ps(scales + i); + _mm_storeu_ps(scales + i, _mm_div_ps(_mm_set1_ps(1.0f), v)); + } +#endif + for (; i < num_blocks; i++) + { + scales[i] = 1.0f / scales[i]; + } +} + +// =================== Int8 SIMD Kernels 
=================== + +static inline int qk_int8_dot_block_scalar(const signed char* a, const signed char* b, int len) +{ + int sum = 0; + for (int i = 0; i < len; i++) + sum += a[i] * b[i]; + return sum; +} + +#if __SSE2__ && !__SSSE3__ +// SSE2-compatible int8 sign-extension helpers (SSSE3 provides _mm_cvtepi8_epi16) +static inline __m128i _mm_cvtepi8_epi16_sse2(__m128i x) +{ + __m128i sign = _mm_cmpgt_epi8(_mm_setzero_si128(), x); + return _mm_unpacklo_epi8(x, sign); +} +#define _mm_cvtepi8_epi16 _mm_cvtepi8_epi16_sse2 +#endif +#if __SSE2__ && !__SSE4_1__ +static inline __m128i _mm_cvtepi8_epi32_sse2(__m128i x) +{ + __m128i sign = _mm_cmpgt_epi8(_mm_setzero_si128(), x); + __m128i x16 = _mm_unpacklo_epi8(x, sign); + __m128i sign16 = _mm_cmpgt_epi16(_mm_setzero_si128(), x16); + return _mm_unpacklo_epi16(x16, sign16); +} +#define _mm_cvtepi8_epi32 _mm_cvtepi8_epi32_sse2 +#endif + +#if __SSE2__ +static inline int qk_int8_dot_block_sse2(const signed char* a, const signed char* b, int len) +{ + __m128i sum = _mm_setzero_si128(); + int i = 0; + for (; i + 15 < len; i += 16) + { + __m128i va = _mm_loadu_si128((const __m128i*)(a + i)); + __m128i vb = _mm_loadu_si128((const __m128i*)(b + i)); + __m128i va_lo = _mm_cvtepi8_epi16(va); + __m128i va_hi = _mm_cvtepi8_epi16(_mm_srli_si128(va, 8)); + __m128i vb_lo = _mm_cvtepi8_epi16(vb); + __m128i vb_hi = _mm_cvtepi8_epi16(_mm_srli_si128(vb, 8)); + sum = _mm_add_epi32(sum, _mm_madd_epi16(va_lo, vb_lo)); + sum = _mm_add_epi32(sum, _mm_madd_epi16(va_hi, vb_hi)); + } + int sum_tail = 0; + for (; i < len; i++) + sum_tail += a[i] * b[i]; + return _mm_reduce_add_epi32(sum) + sum_tail; +} +#endif // __SSE2__ + +#if __AVX2__ +inline int __attribute__((noinline)) qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len) +{ + __m256i sum = _mm256_setzero_si256(); + int i = 0; + for (; i + 31 < len; i += 32) + { + __m256i va = _mm256_loadu_si256((const __m256i*)(a + i)); + __m256i vb = _mm256_loadu_si256((const 
__m256i*)(b + i)); + __m256i va_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(va)); + __m256i va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1)); + __m256i vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb)); + __m256i vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1)); + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(va_lo, vb_lo)); + sum = _mm256_add_epi32(sum, _mm256_madd_epi16(va_hi, vb_hi)); + } + int sum_tail = 0; + for (; i < len; i++) + sum_tail += a[i] * b[i]; + __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum, 1)); + return _mm_reduce_add_epi32(sum128) + sum_tail; +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline int qk_int8_dot_block_avx512(const signed char* a, const signed char* b, int len) +{ + __m512i sum = _mm512_setzero_si512(); + int i = 0; + for (; i + 31 < len; i += 32) + { + __m256i va256 = _mm256_loadu_si256((const __m256i*)(a + i)); + __m256i vb256 = _mm256_loadu_si256((const __m256i*)(b + i)); + __m512i va = _mm512_cvtepi8_epi16(va256); + __m512i vb = _mm512_cvtepi8_epi16(vb256); + sum = _mm512_add_epi32(sum, _mm512_madd_epi16(va, vb)); + } + int sum_tail = 0; + for (; i < len; i++) + sum_tail += a[i] * b[i]; + __m256i sum256 = _mm256_add_epi32(_mm512_castsi512_si256(sum), _mm512_extracti32x8_epi32(sum, 1)); + __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(sum256), _mm256_extracti128_si256(sum256, 1)); + return _mm_reduce_add_epi32(sum128) + sum_tail; +} +#endif // __AVX512F__ + +static inline int qk_int8_dot_block(const signed char* a, const signed char* b, int len) +{ +#if __AVX512F__ + return qk_int8_dot_block_avx512(a, b, len); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + return qk_int8_dot_block_avx2(a, b, len); + } +#endif +#if __AVX2__ + return qk_int8_dot_block_avx2(a, b, len); +#elif __SSE2__ + return qk_int8_dot_block_sse2(a, b, len); +#else + return qk_int8_dot_block_scalar(a, b, len); 
+#endif +#endif +} + +// ------------------- Decode QK Dot Int8 ------------------- + +static inline void decode_qk_dot_int8_scalar(float* s, const signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + for (int j = 0; j < block_n; j++) + { + const signed char* kptr = K + (n_start + j) * d; + const float* ks = kscales + (n_start + j) * num_blocks; + float sum = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_scalar(q + off, kptr + off, len); + sum += (float)block_sum / (qscales[b] * ks[b]); + } + s[j] = sum * scale; + } +} + +#if __SSE2__ +static inline void decode_qk_dot_int8_sse2(float* s, const signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const signed char* k0 = K + (n_start + j + 0) * d; + const signed char* k1 = K + (n_start + j + 1) * d; + const float* ks0 = kscales + (n_start + j + 0) * num_blocks; + const float* ks1 = kscales + (n_start + j + 1) * num_blocks; + + float sum0 = 0.f, sum1 = 0.f; + for (int b = 0; b < num_blocks; b++) + { + float descale = qscales[b] * ks0[b]; + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_sse2(q + off, k0 + off, len); + sum0 += (float)block_sum * descale; + } + for (int b = 0; b < num_blocks; b++) + { + float descale = qscales[b] * ks1[b]; + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_sse2(q + off, k1 + off, len); + sum1 += (float)block_sum * descale; + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + for (; j < block_n; j++) + { + const signed char* kptr = K 
+ (n_start + j) * d; + const float* ks = kscales + (n_start + j) * num_blocks; + float sum = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_scalar(q + off, kptr + off, len); + float descale = qscales[b] * ks[b]; + sum += (float)block_sum * descale; + } + s[j] = sum * scale; + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void __attribute__((noinline)) decode_qk_dot_int8_avx2(float* s, const signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const signed char* k0 = K + (n_start + j + 0) * d; + const signed char* k1 = K + (n_start + j + 1) * d; + const float* ks0 = kscales + (n_start + j + 0) * num_blocks; + const float* ks1 = kscales + (n_start + j + 1) * num_blocks; + + float sum0 = 0.f, sum1 = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum0 = qk_int8_dot_block_avx2(q + off, k0 + off, len); + int block_sum1 = qk_int8_dot_block_avx2(q + off, k1 + off, len); + sum0 += (float)block_sum0 / (qscales[b] * ks0[b]); + sum1 += (float)block_sum1 / (qscales[b] * ks1[b]); + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + for (; j < block_n; j++) + { + const signed char* kptr = K + (n_start + j) * d; + const float* ks = kscales + (n_start + j) * num_blocks; + float sum = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_avx2(q + off, kptr + off, len); + float descale = qscales[b] * ks[b]; + sum += (float)block_sum * descale; + } + s[j] = sum * scale; + } +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void decode_qk_dot_int8_avx512(float* s, const 
signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 3 < block_n; j += 4) + { + const signed char* k0 = K + (n_start + j + 0) * d; + const signed char* k1 = K + (n_start + j + 1) * d; + const signed char* k2 = K + (n_start + j + 2) * d; + const signed char* k3 = K + (n_start + j + 3) * d; + const float* ks0 = kscales + (n_start + j + 0) * num_blocks; + const float* ks1 = kscales + (n_start + j + 1) * num_blocks; + const float* ks2 = kscales + (n_start + j + 2) * num_blocks; + const float* ks3 = kscales + (n_start + j + 3) * num_blocks; + + float sum0 = 0.f, sum1 = 0.f, sum2 = 0.f, sum3 = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int bs; + bs = qk_int8_dot_block_avx512(q + off, k0 + off, len); + sum0 += (float)bs / (qscales[b] * ks0[b]); + bs = qk_int8_dot_block_avx512(q + off, k1 + off, len); + sum1 += (float)bs / (qscales[b] * ks1[b]); + bs = qk_int8_dot_block_avx512(q + off, k2 + off, len); + sum2 += (float)bs / (qscales[b] * ks2[b]); + bs = qk_int8_dot_block_avx512(q + off, k3 + off, len); + sum3 += (float)bs / (qscales[b] * ks3[b]); + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + s[j + 2] = sum2 * scale; + s[j + 3] = sum3 * scale; + } + for (; j < block_n; j++) + { + const signed char* kptr = K + (n_start + j) * d; + const float* ks = kscales + (n_start + j) * num_blocks; + float sum = 0.f; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int block_sum = qk_int8_dot_block_avx512(q + off, kptr + off, len); + float descale = qscales[b] * ks[b]; + sum += (float)block_sum * descale; + } + s[j] = sum * scale; + } +} +#endif // __AVX512F__ + +static void decode_qk_dot_int8_dispatch(float* s, const signed char* q, + const signed char* K, const 
float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ +#if __AVX512F__ + decode_qk_dot_int8_avx512(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + decode_qk_dot_int8_avx2(s, q, K, qscales, kscales, n_start, block_n, d, scale); + return; + } +#endif +#if __AVX2__ + decode_qk_dot_int8_avx2(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#elif __SSE2__ + decode_qk_dot_int8_sse2(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#else + decode_qk_dot_int8_scalar(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#endif +#endif +} + +// ------------------- Prefill QK Int8 GEMM Row-wise ------------------- + +static inline void qk_int8_gemm_row_scalar(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + for (int j = 0; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum += qk_int8_dot_block_scalar(q_row + off, kptr + off, len); + } + s_row[j] = (float)sum * qscale * kscales[j] * scale; + } +} + +#if __SSE2__ +static inline void qk_int8_gemm_row_sse2(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_sse2(q_row + off, k0 + off, len); + sum1 += qk_int8_dot_block_sse2(q_row + off, k1 + off, len); + } + s_row[j] = 
(float)sum0 * qscale * kscales[j] * scale; + s_row[j + 1] = (float)sum1 * qscale * kscales[j + 1] * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum += qk_int8_dot_block_scalar(q_row + off, kptr + off, len); + } + s_row[j] = (float)sum * qscale * kscales[j] * scale; + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void __attribute__((noinline)) qk_int8_gemm_row_avx2(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_avx2(q_row + off, k0 + off, len); + sum1 += qk_int8_dot_block_avx2(q_row + off, k1 + off, len); + } + s_row[j] = (float)sum0 * qscale * kscales[j] * scale; + s_row[j + 1] = (float)sum1 * qscale * kscales[j + 1] * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum += qk_int8_dot_block_avx2(q_row + off, kptr + off, len); + } + s_row[j] = (float)sum * qscale * kscales[j] * scale; + } +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void qk_int8_gemm_row_avx512(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j + 3 < n; j += 4) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + const signed char* k2 = K + (j + 2) * d; + 
const signed char* k3 = K + (j + 3) * d; + + __m512i acc0 = _mm512_setzero_epi32(); + __m512i acc1 = _mm512_setzero_epi32(); + __m512i acc2 = _mm512_setzero_epi32(); + __m512i acc3 = _mm512_setzero_epi32(); + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q_32 = _mm256_loadu_si256((const __m256i*)(q_row + off)); + __m512i q_512 = _mm512_cvtepi8_epi16(q_32); + + __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); + __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); + __m256i k2_32 = _mm256_loadu_si256((const __m256i*)(k2 + off)); + __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off)); + __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32); + __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32); + __m512i k2_512 = _mm512_cvtepi8_epi16(k2_32); + __m512i k3_512 = _mm512_cvtepi8_epi16(k3_32); + + acc0 = _mm512_add_epi32(acc0, _mm512_madd_epi16(q_512, k0_512)); + acc1 = _mm512_add_epi32(acc1, _mm512_madd_epi16(q_512, k1_512)); + acc2 = _mm512_add_epi32(acc2, _mm512_madd_epi16(q_512, k2_512)); + acc3 = _mm512_add_epi32(acc3, _mm512_madd_epi16(q_512, k3_512)); + } + else + { + int bs; + bs = qk_int8_dot_block_avx512(q_row + off, k0 + off, len); + scalar0 += bs; + bs = qk_int8_dot_block_avx512(q_row + off, k1 + off, len); + scalar1 += bs; + bs = qk_int8_dot_block_avx512(q_row + off, k2 + off, len); + scalar2 += bs; + bs = qk_int8_dot_block_avx512(q_row + off, k3 + off, len); + scalar3 += bs; + } + } + float descale0 = qscale * kscales[j]; + float descale1 = qscale * kscales[j + 1]; + float descale2 = qscale * kscales[j + 2]; + float descale3 = qscale * kscales[j + 3]; + float sum0 = (float)_mm512_reduce_add_epi32(acc0) + (float)scalar0; + float sum1 = (float)_mm512_reduce_add_epi32(acc1) + (float)scalar1; + float sum2 = (float)_mm512_reduce_add_epi32(acc2) + (float)scalar2; + float sum3 = 
(float)_mm512_reduce_add_epi32(acc3) + (float)scalar3; + s_row[j] = sum0 * descale0 * scale; + s_row[j + 1] = sum1 * descale1 * scale; + s_row[j + 2] = sum2 * descale2 * scale; + s_row[j + 3] = sum3 * descale3 * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + __m512i acc = _mm512_setzero_epi32(); + int scalar_sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q_32 = _mm256_loadu_si256((const __m256i*)(q_row + off)); + __m256i k_32 = _mm256_loadu_si256((const __m256i*)(kptr + off)); + __m512i q_512 = _mm512_cvtepi8_epi16(q_32); + __m512i k_512 = _mm512_cvtepi8_epi16(k_32); + acc = _mm512_add_epi32(acc, _mm512_madd_epi16(q_512, k_512)); + } + else + { + scalar_sum += qk_int8_dot_block_avx512(q_row + off, kptr + off, len); + } + } + float sum = (float)_mm512_reduce_add_epi32(acc) + (float)scalar_sum; + s_row[j] = sum * qscale * kscales[j] * scale; + } +} +#endif // __AVX512F__ + +static void qk_int8_gemm_row_dispatch(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ +#if __AVX512F__ + qk_int8_gemm_row_avx512(s_row, q_row, K, qscale, kscales, n, d, scale); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + qk_int8_gemm_row_avx2(s_row, q_row, K, qscale, kscales, n, d, scale); + return; + } +#endif +#if __AVX2__ + qk_int8_gemm_row_avx2(s_row, q_row, K, qscale, kscales, n, d, scale); +#elif __SSE2__ + qk_int8_gemm_row_sse2(s_row, q_row, K, qscale, kscales, n, d, scale); +#else + qk_int8_gemm_row_scalar(s_row, q_row, K, qscale, kscales, n, d, scale); +#endif +#endif +} + +// ------------------- Tiled QK Int8 GEMM (M-tiling) ------------------- + +static inline void qk_int8_gemm_tiled_scalar(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int 
m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 2 <= m; i += 2) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + for (int j = 0; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_scalar(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_scalar(q1 + off, kptr + off, len); + } + S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; + S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_scalar(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} + +#if __SSE2__ +static inline void qk_int8_gemm_tiled_sse2(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 2 <= m; i += 2) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + int j = 0; + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + int sum00 = 0, sum01 = 0, sum10 = 0, sum11 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int bs; + bs = qk_int8_dot_block_sse2(q0 + off, k0 + off, len); + sum00 += bs; + bs = qk_int8_dot_block_sse2(q0 + off, k1 + off, len); + sum01 += bs; + bs = qk_int8_dot_block_sse2(q1 + off, k0 + off, len); + sum10 += bs; + bs = qk_int8_dot_block_sse2(q1 + off, k1 + off, len); + sum11 += bs; + } + S[(i + 0) * n + j] = (float)sum00 * qs0 * kscales[j] * scale; + S[(i + 0) * n + j + 1] = 
(float)sum01 * qs0 * kscales[j + 1] * scale; + S[(i + 1) * n + j] = (float)sum10 * qs1 * kscales[j] * scale; + S[(i + 1) * n + j + 1] = (float)sum11 * qs1 * kscales[j + 1] * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_sse2(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_sse2(q1 + off, kptr + off, len); + } + S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; + S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_sse2(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} +#endif // __SSE2__ + +#if __AVX2__ +static inline int _mm256_reduce_add_epi32(__m256i v) +{ + __m128i vlow = _mm256_castsi256_si128(v); + __m128i vhigh = _mm256_extracti128_si256(v, 1); + vlow = _mm_add_epi32(vlow, vhigh); + return _mm_reduce_add_epi32(vlow); +} + +#if __AVX512F__ +static inline float _mm512_reduce_add_ps(__m512 v) +{ + __m256 vlow = _mm512_castps512_ps256(v); + __m256 vhigh = _mm512_extractf32x8_ps(v, 1); + return _mm256_reduce_add_ps(_mm256_add_ps(vlow, vhigh)); +} + +static inline int _mm512_reduce_add_epi32(__m512i v) +{ + __m256i vlow = _mm512_castsi512_si256(v); + __m256i vhigh = _mm512_extracti32x8_epi32(v, 1); + return _mm256_reduce_add_epi32(_mm256_add_epi32(vlow, vhigh)); +} +#endif // __AVX512F__ + +inline void __attribute__((noinline)) qk_int8_gemm_tiled_avx2(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 2 <= m; i += 2) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + int j = 0; + for (; j + 3 < n; j += 4) + { + 
const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + const signed char* k2 = K + (j + 2) * d; + const signed char* k3 = K + (j + 3) * d; + + __m256i acc00 = _mm256_setzero_si256(); + __m256i acc01 = _mm256_setzero_si256(); + __m256i acc02 = _mm256_setzero_si256(); + __m256i acc03 = _mm256_setzero_si256(); + __m256i acc10 = _mm256_setzero_si256(); + __m256i acc11 = _mm256_setzero_si256(); + __m256i acc12 = _mm256_setzero_si256(); + __m256i acc13 = _mm256_setzero_si256(); + int scalar00 = 0, scalar01 = 0, scalar02 = 0, scalar03 = 0; + int scalar10 = 0, scalar11 = 0, scalar12 = 0, scalar13 = 0; + + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off)); + __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off)); + __m256i q0_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q0_32)); + __m256i q0_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q0_32, 1)); + __m256i q1_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q1_32)); + __m256i q1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q1_32, 1)); + + __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); + __m256i k0_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k0_32)); + __m256i k0_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k0_32, 1)); + acc00 = _mm256_add_epi32(acc00, _mm256_madd_epi16(q0_lo, k0_lo)); + acc00 = _mm256_add_epi32(acc00, _mm256_madd_epi16(q0_hi, k0_hi)); + acc10 = _mm256_add_epi32(acc10, _mm256_madd_epi16(q1_lo, k0_lo)); + acc10 = _mm256_add_epi32(acc10, _mm256_madd_epi16(q1_hi, k0_hi)); + + __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); + __m256i k1_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k1_32)); + __m256i k1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k1_32, 1)); + acc01 = _mm256_add_epi32(acc01, _mm256_madd_epi16(q0_lo, k1_lo)); + acc01 = 
_mm256_add_epi32(acc01, _mm256_madd_epi16(q0_hi, k1_hi)); + acc11 = _mm256_add_epi32(acc11, _mm256_madd_epi16(q1_lo, k1_lo)); + acc11 = _mm256_add_epi32(acc11, _mm256_madd_epi16(q1_hi, k1_hi)); + + __m256i k2_32 = _mm256_loadu_si256((const __m256i*)(k2 + off)); + __m256i k2_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k2_32)); + __m256i k2_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k2_32, 1)); + acc02 = _mm256_add_epi32(acc02, _mm256_madd_epi16(q0_lo, k2_lo)); + acc02 = _mm256_add_epi32(acc02, _mm256_madd_epi16(q0_hi, k2_hi)); + acc12 = _mm256_add_epi32(acc12, _mm256_madd_epi16(q1_lo, k2_lo)); + acc12 = _mm256_add_epi32(acc12, _mm256_madd_epi16(q1_hi, k2_hi)); + + __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off)); + __m256i k3_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k3_32)); + __m256i k3_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k3_32, 1)); + acc03 = _mm256_add_epi32(acc03, _mm256_madd_epi16(q0_lo, k3_lo)); + acc03 = _mm256_add_epi32(acc03, _mm256_madd_epi16(q0_hi, k3_hi)); + acc13 = _mm256_add_epi32(acc13, _mm256_madd_epi16(q1_lo, k3_lo)); + acc13 = _mm256_add_epi32(acc13, _mm256_madd_epi16(q1_hi, k3_hi)); + } + else + { + int bs; + bs = qk_int8_dot_block_avx2(q0 + off, k0 + off, len); + scalar00 += bs; + bs = qk_int8_dot_block_avx2(q0 + off, k1 + off, len); + scalar01 += bs; + bs = qk_int8_dot_block_avx2(q0 + off, k2 + off, len); + scalar02 += bs; + bs = qk_int8_dot_block_avx2(q0 + off, k3 + off, len); + scalar03 += bs; + bs = qk_int8_dot_block_avx2(q1 + off, k0 + off, len); + scalar10 += bs; + bs = qk_int8_dot_block_avx2(q1 + off, k1 + off, len); + scalar11 += bs; + bs = qk_int8_dot_block_avx2(q1 + off, k2 + off, len); + scalar12 += bs; + bs = qk_int8_dot_block_avx2(q1 + off, k3 + off, len); + scalar13 += bs; + } + } + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale02 = qs0 * kscales[j + 2]; + float descale03 = qs0 * kscales[j + 3]; + float descale10 = qs1 * kscales[j]; + 
float descale11 = qs1 * kscales[j + 1]; + float descale12 = qs1 * kscales[j + 2]; + float descale13 = qs1 * kscales[j + 3]; + S[(i + 0) * n + j] = ((float)_mm256_reduce_add_epi32(acc00) + (float)scalar00) * descale00 * scale; + S[(i + 0) * n + j + 1] = ((float)_mm256_reduce_add_epi32(acc01) + (float)scalar01) * descale01 * scale; + S[(i + 0) * n + j + 2] = ((float)_mm256_reduce_add_epi32(acc02) + (float)scalar02) * descale02 * scale; + S[(i + 0) * n + j + 3] = ((float)_mm256_reduce_add_epi32(acc03) + (float)scalar03) * descale03 * scale; + S[(i + 1) * n + j] = ((float)_mm256_reduce_add_epi32(acc10) + (float)scalar10) * descale10 * scale; + S[(i + 1) * n + j + 1] = ((float)_mm256_reduce_add_epi32(acc11) + (float)scalar11) * descale11 * scale; + S[(i + 1) * n + j + 2] = ((float)_mm256_reduce_add_epi32(acc12) + (float)scalar12) * descale12 * scale; + S[(i + 1) * n + j + 3] = ((float)_mm256_reduce_add_epi32(acc13) + (float)scalar13) * descale13 * scale; + } + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + int sum00 = 0, sum01 = 0, sum10 = 0, sum11 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + int bs; + bs = qk_int8_dot_block_avx2(q0 + off, k0 + off, len); + sum00 += bs; + bs = qk_int8_dot_block_avx2(q0 + off, k1 + off, len); + sum01 += bs; + bs = qk_int8_dot_block_avx2(q1 + off, k0 + off, len); + sum10 += bs; + bs = qk_int8_dot_block_avx2(q1 + off, k1 + off, len); + sum11 += bs; + } + S[(i + 0) * n + j] = (float)sum00 * qs0 * kscales[j] * scale; + S[(i + 0) * n + j + 1] = (float)sum01 * qs0 * kscales[j + 1] * scale; + S[(i + 1) * n + j] = (float)sum10 * qs1 * kscales[j] * scale; + S[(i + 1) * n + j + 1] = (float)sum11 * qs1 * kscales[j + 1] * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int sum0 = 0, sum1 = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = 
std::min(32, d - off); + if (len <= 0) continue; + sum0 += qk_int8_dot_block_avx2(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_avx2(q1 + off, kptr + off, len); + } + S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; + S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_avx2(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void qk_int8_gemm_tiled_avx512(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int i = 0; + for (; i + 4 <= m; i += 4) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + const signed char* q2 = Q + (i + 2) * d; + const signed char* q3 = Q + (i + 3) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + float qs2 = qscales[i + 2]; + float qs3 = qscales[i + 3]; + int j = 0; + for (; j + 3 < n; j += 4) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + const signed char* k2 = K + (j + 2) * d; + const signed char* k3 = K + (j + 3) * d; + + __m512i acc00 = _mm512_setzero_epi32(); + __m512i acc01 = _mm512_setzero_epi32(); + __m512i acc02 = _mm512_setzero_epi32(); + __m512i acc03 = _mm512_setzero_epi32(); + __m512i acc10 = _mm512_setzero_epi32(); + __m512i acc11 = _mm512_setzero_epi32(); + __m512i acc12 = _mm512_setzero_epi32(); + __m512i acc13 = _mm512_setzero_epi32(); + __m512i acc20 = _mm512_setzero_epi32(); + __m512i acc21 = _mm512_setzero_epi32(); + __m512i acc22 = _mm512_setzero_epi32(); + __m512i acc23 = _mm512_setzero_epi32(); + __m512i acc30 = _mm512_setzero_epi32(); + __m512i acc31 = _mm512_setzero_epi32(); + __m512i acc32 = _mm512_setzero_epi32(); + __m512i acc33 = _mm512_setzero_epi32(); + int scalar00 = 0, scalar01 = 0, scalar02 = 0, scalar03 = 0; + int scalar10 = 0, 
scalar11 = 0, scalar12 = 0, scalar13 = 0; + int scalar20 = 0, scalar21 = 0, scalar22 = 0, scalar23 = 0; + int scalar30 = 0, scalar31 = 0, scalar32 = 0, scalar33 = 0; + + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off)); + __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off)); + __m256i q2_32 = _mm256_loadu_si256((const __m256i*)(q2 + off)); + __m256i q3_32 = _mm256_loadu_si256((const __m256i*)(q3 + off)); + __m512i q0_512 = _mm512_cvtepi8_epi16(q0_32); + __m512i q1_512 = _mm512_cvtepi8_epi16(q1_32); + __m512i q2_512 = _mm512_cvtepi8_epi16(q2_32); + __m512i q3_512 = _mm512_cvtepi8_epi16(q3_32); + + __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); + __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); + __m256i k2_32 = _mm256_loadu_si256((const __m256i*)(k2 + off)); + __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off)); + __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32); + __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32); + __m512i k2_512 = _mm512_cvtepi8_epi16(k2_32); + __m512i k3_512 = _mm512_cvtepi8_epi16(k3_32); + + acc00 = _mm512_add_epi32(acc00, _mm512_madd_epi16(q0_512, k0_512)); + acc01 = _mm512_add_epi32(acc01, _mm512_madd_epi16(q0_512, k1_512)); + acc02 = _mm512_add_epi32(acc02, _mm512_madd_epi16(q0_512, k2_512)); + acc03 = _mm512_add_epi32(acc03, _mm512_madd_epi16(q0_512, k3_512)); + acc10 = _mm512_add_epi32(acc10, _mm512_madd_epi16(q1_512, k0_512)); + acc11 = _mm512_add_epi32(acc11, _mm512_madd_epi16(q1_512, k1_512)); + acc12 = _mm512_add_epi32(acc12, _mm512_madd_epi16(q1_512, k2_512)); + acc13 = _mm512_add_epi32(acc13, _mm512_madd_epi16(q1_512, k3_512)); + acc20 = _mm512_add_epi32(acc20, _mm512_madd_epi16(q2_512, k0_512)); + acc21 = _mm512_add_epi32(acc21, _mm512_madd_epi16(q2_512, k1_512)); + acc22 = _mm512_add_epi32(acc22, _mm512_madd_epi16(q2_512, k2_512)); 
+ acc23 = _mm512_add_epi32(acc23, _mm512_madd_epi16(q2_512, k3_512)); + acc30 = _mm512_add_epi32(acc30, _mm512_madd_epi16(q3_512, k0_512)); + acc31 = _mm512_add_epi32(acc31, _mm512_madd_epi16(q3_512, k1_512)); + acc32 = _mm512_add_epi32(acc32, _mm512_madd_epi16(q3_512, k2_512)); + acc33 = _mm512_add_epi32(acc33, _mm512_madd_epi16(q3_512, k3_512)); + } + else + { + int bs; + bs = qk_int8_dot_block_avx512(q0 + off, k0 + off, len); + scalar00 += bs; + bs = qk_int8_dot_block_avx512(q0 + off, k1 + off, len); + scalar01 += bs; + bs = qk_int8_dot_block_avx512(q0 + off, k2 + off, len); + scalar02 += bs; + bs = qk_int8_dot_block_avx512(q0 + off, k3 + off, len); + scalar03 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, k0 + off, len); + scalar10 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, k1 + off, len); + scalar11 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, k2 + off, len); + scalar12 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, k3 + off, len); + scalar13 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, k0 + off, len); + scalar20 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, k1 + off, len); + scalar21 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, k2 + off, len); + scalar22 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, k3 + off, len); + scalar23 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, k0 + off, len); + scalar30 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, k1 + off, len); + scalar31 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, k2 + off, len); + scalar32 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, k3 + off, len); + scalar33 += bs; + } + } + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale02 = qs0 * kscales[j + 2]; + float descale03 = qs0 * kscales[j + 3]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale12 = qs1 * kscales[j + 2]; + float descale13 = qs1 * kscales[j + 3]; + float descale20 = qs2 * kscales[j]; + float descale21 = qs2 * 
kscales[j + 1]; + float descale22 = qs2 * kscales[j + 2]; + float descale23 = qs2 * kscales[j + 3]; + float descale30 = qs3 * kscales[j]; + float descale31 = qs3 * kscales[j + 1]; + float descale32 = qs3 * kscales[j + 2]; + float descale33 = qs3 * kscales[j + 3]; + float sum00 = (float)_mm512_reduce_add_epi32(acc00) + (float)scalar00; + float sum01 = (float)_mm512_reduce_add_epi32(acc01) + (float)scalar01; + float sum02 = (float)_mm512_reduce_add_epi32(acc02) + (float)scalar02; + float sum03 = (float)_mm512_reduce_add_epi32(acc03) + (float)scalar03; + float sum10 = (float)_mm512_reduce_add_epi32(acc10) + (float)scalar10; + float sum11 = (float)_mm512_reduce_add_epi32(acc11) + (float)scalar11; + float sum12 = (float)_mm512_reduce_add_epi32(acc12) + (float)scalar12; + float sum13 = (float)_mm512_reduce_add_epi32(acc13) + (float)scalar13; + float sum20 = (float)_mm512_reduce_add_epi32(acc20) + (float)scalar20; + float sum21 = (float)_mm512_reduce_add_epi32(acc21) + (float)scalar21; + float sum22 = (float)_mm512_reduce_add_epi32(acc22) + (float)scalar22; + float sum23 = (float)_mm512_reduce_add_epi32(acc23) + (float)scalar23; + float sum30 = (float)_mm512_reduce_add_epi32(acc30) + (float)scalar30; + float sum31 = (float)_mm512_reduce_add_epi32(acc31) + (float)scalar31; + float sum32 = (float)_mm512_reduce_add_epi32(acc32) + (float)scalar32; + float sum33 = (float)_mm512_reduce_add_epi32(acc33) + (float)scalar33; + S[(i + 0) * n + j + 0] = sum00 * descale00 * scale; + S[(i + 0) * n + j + 1] = sum01 * descale01 * scale; + S[(i + 0) * n + j + 2] = sum02 * descale02 * scale; + S[(i + 0) * n + j + 3] = sum03 * descale03 * scale; + S[(i + 1) * n + j + 0] = sum10 * descale10 * scale; + S[(i + 1) * n + j + 1] = sum11 * descale11 * scale; + S[(i + 1) * n + j + 2] = sum12 * descale12 * scale; + S[(i + 1) * n + j + 3] = sum13 * descale13 * scale; + S[(i + 2) * n + j + 0] = sum20 * descale20 * scale; + S[(i + 2) * n + j + 1] = sum21 * descale21 * scale; + S[(i + 2) * n + j + 2] = 
sum22 * descale22 * scale; + S[(i + 2) * n + j + 3] = sum23 * descale23 * scale; + S[(i + 3) * n + j + 0] = sum30 * descale30 * scale; + S[(i + 3) * n + j + 1] = sum31 * descale31 * scale; + S[(i + 3) * n + j + 2] = sum32 * descale32 * scale; + S[(i + 3) * n + j + 3] = sum33 * descale33 * scale; + } + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + + __m512i acc00 = _mm512_setzero_epi32(); + __m512i acc01 = _mm512_setzero_epi32(); + __m512i acc10 = _mm512_setzero_epi32(); + __m512i acc11 = _mm512_setzero_epi32(); + __m512i acc20 = _mm512_setzero_epi32(); + __m512i acc21 = _mm512_setzero_epi32(); + __m512i acc30 = _mm512_setzero_epi32(); + __m512i acc31 = _mm512_setzero_epi32(); + int scalar00 = 0, scalar01 = 0, scalar10 = 0, scalar11 = 0; + int scalar20 = 0, scalar21 = 0, scalar30 = 0, scalar31 = 0; + + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off)); + __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off)); + __m256i q2_32 = _mm256_loadu_si256((const __m256i*)(q2 + off)); + __m256i q3_32 = _mm256_loadu_si256((const __m256i*)(q3 + off)); + __m512i q0_512 = _mm512_cvtepi8_epi16(q0_32); + __m512i q1_512 = _mm512_cvtepi8_epi16(q1_32); + __m512i q2_512 = _mm512_cvtepi8_epi16(q2_32); + __m512i q3_512 = _mm512_cvtepi8_epi16(q3_32); + + __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); + __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); + __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32); + __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32); + + acc00 = _mm512_add_epi32(acc00, _mm512_madd_epi16(q0_512, k0_512)); + acc01 = _mm512_add_epi32(acc01, _mm512_madd_epi16(q0_512, k1_512)); + acc10 = _mm512_add_epi32(acc10, _mm512_madd_epi16(q1_512, k0_512)); + acc11 = _mm512_add_epi32(acc11, _mm512_madd_epi16(q1_512, k1_512)); + acc20 = 
_mm512_add_epi32(acc20, _mm512_madd_epi16(q2_512, k0_512)); + acc21 = _mm512_add_epi32(acc21, _mm512_madd_epi16(q2_512, k1_512)); + acc30 = _mm512_add_epi32(acc30, _mm512_madd_epi16(q3_512, k0_512)); + acc31 = _mm512_add_epi32(acc31, _mm512_madd_epi16(q3_512, k1_512)); + } + else + { + int bs; + bs = qk_int8_dot_block_avx512(q0 + off, k0 + off, len); + scalar00 += bs; + bs = qk_int8_dot_block_avx512(q0 + off, k1 + off, len); + scalar01 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, k0 + off, len); + scalar10 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, k1 + off, len); + scalar11 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, k0 + off, len); + scalar20 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, k1 + off, len); + scalar21 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, k0 + off, len); + scalar30 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, k1 + off, len); + scalar31 += bs; + } + } + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale20 = qs2 * kscales[j]; + float descale21 = qs2 * kscales[j + 1]; + float descale30 = qs3 * kscales[j]; + float descale31 = qs3 * kscales[j + 1]; + float sum00 = (float)_mm512_reduce_add_epi32(acc00) + (float)scalar00; + float sum01 = (float)_mm512_reduce_add_epi32(acc01) + (float)scalar01; + float sum10 = (float)_mm512_reduce_add_epi32(acc10) + (float)scalar10; + float sum11 = (float)_mm512_reduce_add_epi32(acc11) + (float)scalar11; + float sum20 = (float)_mm512_reduce_add_epi32(acc20) + (float)scalar20; + float sum21 = (float)_mm512_reduce_add_epi32(acc21) + (float)scalar21; + float sum30 = (float)_mm512_reduce_add_epi32(acc30) + (float)scalar30; + float sum31 = (float)_mm512_reduce_add_epi32(acc31) + (float)scalar31; + S[(i + 0) * n + j] = sum00 * descale00 * scale; + S[(i + 0) * n + j + 1] = sum01 * descale01 * scale; + S[(i + 1) * n + j] = sum10 * descale10 * scale; + S[(i + 1) * n + j + 1] = 
sum11 * descale11 * scale; + S[(i + 2) * n + j] = sum20 * descale20 * scale; + S[(i + 2) * n + j + 1] = sum21 * descale21 * scale; + S[(i + 3) * n + j] = sum30 * descale30 * scale; + S[(i + 3) * n + j + 1] = sum31 * descale31 * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + + __m512i acc0 = _mm512_setzero_epi32(); + __m512i acc1 = _mm512_setzero_epi32(); + __m512i acc2 = _mm512_setzero_epi32(); + __m512i acc3 = _mm512_setzero_epi32(); + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + if (len == 32) + { + __m256i q0_32 = _mm256_loadu_si256((const __m256i*)(q0 + off)); + __m256i q1_32 = _mm256_loadu_si256((const __m256i*)(q1 + off)); + __m256i q2_32 = _mm256_loadu_si256((const __m256i*)(q2 + off)); + __m256i q3_32 = _mm256_loadu_si256((const __m256i*)(q3 + off)); + __m512i q0_512 = _mm512_cvtepi8_epi16(q0_32); + __m512i q1_512 = _mm512_cvtepi8_epi16(q1_32); + __m512i q2_512 = _mm512_cvtepi8_epi16(q2_32); + __m512i q3_512 = _mm512_cvtepi8_epi16(q3_32); + + __m256i k_32 = _mm256_loadu_si256((const __m256i*)(kptr + off)); + __m512i k_512 = _mm512_cvtepi8_epi16(k_32); + + acc0 = _mm512_add_epi32(acc0, _mm512_madd_epi16(q0_512, k_512)); + acc1 = _mm512_add_epi32(acc1, _mm512_madd_epi16(q1_512, k_512)); + acc2 = _mm512_add_epi32(acc2, _mm512_madd_epi16(q2_512, k_512)); + acc3 = _mm512_add_epi32(acc3, _mm512_madd_epi16(q3_512, k_512)); + } + else + { + int bs; + bs = qk_int8_dot_block_avx512(q0 + off, kptr + off, len); + scalar0 += bs; + bs = qk_int8_dot_block_avx512(q1 + off, kptr + off, len); + scalar1 += bs; + bs = qk_int8_dot_block_avx512(q2 + off, kptr + off, len); + scalar2 += bs; + bs = qk_int8_dot_block_avx512(q3 + off, kptr + off, len); + scalar3 += bs; + } + } + float descale = kscales[j]; + float sum0 = (float)_mm512_reduce_add_epi32(acc0) + (float)scalar0; + float sum1 = 
(float)_mm512_reduce_add_epi32(acc1) + (float)scalar1; + float sum2 = (float)_mm512_reduce_add_epi32(acc2) + (float)scalar2; + float sum3 = (float)_mm512_reduce_add_epi32(acc3) + (float)scalar3; + S[(i + 0) * n + j] = sum0 * qs0 * descale * scale; + S[(i + 1) * n + j] = sum1 * qs1 * descale * scale; + S[(i + 2) * n + j] = sum2 * qs2 * descale * scale; + S[(i + 3) * n + j] = sum3 * qs3 * descale * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_avx512(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} +#endif // __AVX512F__ + +static void qk_int8_gemm_tiled_dispatch(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ +#if __AVX512F__ + qk_int8_gemm_tiled_avx512(S, Q, K, qscales, kscales, m, n, d, scale); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + qk_int8_gemm_tiled_avx2(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if __AVX2__ + qk_int8_gemm_tiled_avx2(S, Q, K, qscales, kscales, m, n, d, scale); +#elif __SSE2__ + qk_int8_gemm_tiled_sse2(S, Q, K, qscales, kscales, m, n, d, scale); +#else + qk_int8_gemm_tiled_scalar(S, Q, K, qscales, kscales, m, n, d, scale); +#endif +#endif +} + +// ------------------- Decode PV GEMV Int8 ------------------- + +static inline void decode_pv_gemv_int8_scalar(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + for (int k = k_start; k < k_end; k++) + out[k] = 0.f; + + for (int j = 0; j < block_n; j++) + { + float p = s[j]; + float inv_scale = 1.f / vscales[(n_start + j) * num_blocks + vb]; + const signed char* vptr = V + (n_start + j) * out_d + k_start; + for (int k = k_start; k < k_end; k++) + out[k] += p * (float)vptr[k - k_start] * inv_scale; + } + } +} + +#if __SSE2__ +static inline void decode_pv_gemv_int8_sse2(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? k_start + 32 : out_d; + + int k = k_start; + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_set_epi8(0,0,0,0,0,0,0,0,0,0,0,0, + V[(n_start+j)*out_d+k+3], V[(n_start+j)*out_d+k+2], + V[(n_start+j)*out_d+k+1], V[(n_start+j)*out_d+k+0]); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + sum += p_invscale * V[(n_start + j) * out_d + k]; + } + out[k] = sum; + } + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void __attribute__((noinline)) decode_pv_gemv_int8_avx2(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + + int k = k_start; + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + (n_start + j) * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, oval); + } + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(V + (n_start + j) * out_d + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < block_n; j++) + sum += s[j] / vscales[(n_start + j) * num_blocks + vb] * V[(n_start + j) * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX__ + +#if __AVX512F__ +static inline void decode_pv_gemv_int8_avx512(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + + int k = k_start; + for (; k + 15 < k_end; k += 16) + { + __m512 oval = _mm512_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m512 pvec = _mm512_set1_ps(p_invscale); + __m128i v8 = _mm_loadu_si128((const __m128i*)(V + (n_start + j) * out_d + k)); + __m512i v32 = _mm512_cvtepi8_epi32(v8); + __m512 vfp = _mm512_cvtepi32_ps(v32); + oval = _mm512_fmadd_ps(pvec, vfp, oval); + } + _mm512_storeu_ps(out + k, oval); + } + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + (n_start + j) * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < block_n; j++) + sum += s[j] / vscales[(n_start + j) * num_blocks + vb] * V[(n_start + j) * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX512F__ + +static void decode_pv_gemv_int8_dispatch(float* out, const float* s, + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) +{ +#if __AVX512F__ + decode_pv_gemv_int8_avx512(out, s, V, vscales, n_start, block_n, out_d); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + decode_pv_gemv_int8_avx2(out, s, V, vscales, n_start, block_n, out_d); + return; + } +#endif +#if __AVX2__ + decode_pv_gemv_int8_avx2(out, s, V, vscales, n_start, block_n, out_d); +#elif __SSE2__ + decode_pv_gemv_int8_sse2(out, s, V, vscales, n_start, block_n, out_d); +#else + decode_pv_gemv_int8_scalar(out, s, V, vscales, n_start, block_n, out_d); +#endif +#endif +} + +// ------------------- Prefill PV Float×Int8 
GEMM Row-wise ------------------- + +static inline void pv_float_int8_gemm_row_scalar(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? k_start + 32 : out_d; + for (int k = k_start; k < k_end; k++) + out[k] = 0.f; + + for (int j = 0; j < n; j++) + { + float p = p_row[j]; + float inv_scale = 1.f / vscales[j * num_blocks + vb]; + const signed char* vptr = V + j * out_d + k_start; + for (int k = k_start; k < k_end; k++) + out[k] += p * (float)vptr[k - k_start] * inv_scale; + } + } +} + +#if __SSE2__ +static inline void pv_float_int8_gemm_row_sse2(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + int k = k_start; + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(V + j * out_d + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < n; j++) + sum += p_row[j] / vscales[j * num_blocks + vb] * V[j * out_d + k]; + out[k] = sum; + } + } +} +#endif // __SSE2__ + +#if __AVX2__ +inline void __attribute__((noinline)) pv_float_int8_gemm_row_avx2(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + int k = k_start; + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + j * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, oval); + } + for (; k + 3 < k_end; k += 4) + { + __m128 oval = _mm_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m128 pvec = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(V + j * out_d + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); + } + _mm_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < n; j++) + sum += p_row[j] / vscales[j * num_blocks + vb] * V[j * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void pv_float_int8_gemm_row_avx512(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ + const int num_blocks = (out_d + 31) / 32; + for (int vb = 0; vb < num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + int k = k_start; + for (; k + 15 < k_end; k += 16) + { + __m512 oval = _mm512_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m512 pvec = _mm512_set1_ps(p_invscale); + __m128i v8 = _mm_loadu_si128((const __m128i*)(V + j * out_d + k)); + __m512i v32 = _mm512_cvtepi8_epi32(v8); + __m512 vfp = _mm512_cvtepi32_ps(v32); + oval = _mm512_fmadd_ps(pvec, vfp, oval); + } + _mm512_storeu_ps(out + k, oval); + } + for (; k + 7 < k_end; k += 8) + { + __m256 oval = _mm256_setzero_ps(); + for (int j = 0; j < n; j++) + { + float p_invscale = p_row[j] / vscales[j * num_blocks + vb]; + __m256 pvec = _mm256_set1_ps(p_invscale); + __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + j * out_d + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + oval = _mm256_fmadd_ps(pvec, vfp, oval); + } + _mm256_storeu_ps(out + k, oval); + } + for (; k < k_end; k++) + { + float sum = 0.f; + for (int j = 0; j < n; j++) + sum += p_row[j] / vscales[j * num_blocks + vb] * V[j * out_d + k]; + out[k] = sum; + } + } +} +#endif // __AVX512F__ + +static void pv_float_int8_gemm_row_dispatch(float* out, const float* p_row, + const signed char* V, const float* vscales, + int n, int out_d) +{ +#if __AVX512F__ + pv_float_int8_gemm_row_avx512(out, p_row, V, vscales, n, out_d); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + pv_float_int8_gemm_row_avx2(out, p_row, V, vscales, n, out_d); + return; + } +#endif +#if __AVX2__ + pv_float_int8_gemm_row_avx2(out, p_row, V, vscales, n, out_d); +#elif __SSE2__ + pv_float_int8_gemm_row_sse2(out, p_row, V, vscales, n, out_d); +#else + pv_float_int8_gemm_row_scalar(out, p_row, V, vscales, n, out_d); +#endif +#endif +} + +// ------------------- PV Float×Int8 FMA Block (online softmax) ------------------- + +static inline void pv_float_int8_fma_block_scalar(float* out, float p_invscale, const 
signed char* v, int len) +{ + for (int k = 0; k < len; k++) + out[k] += p_invscale * v[k]; +} + +#if __SSE2__ +static inline void pv_float_int8_fma_block_sse2(float* out, float p_invscale, const signed char* v, int len) +{ + __m128 pvec = _mm_set1_ps(p_invscale); + int k = 0; + for (; k + 3 < len; k += 4) + { + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(v + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + _mm_storeu_ps(out + k, _mm_add_ps(_mm_loadu_ps(out + k), _mm_mul_ps(pvec, vfp))); + } + for (; k < len; k++) + out[k] += p_invscale * v[k]; +} +#endif // __SSE2__ + +#if __AVX2__ +inline void __attribute__((noinline)) pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed char* v, int len) +{ + __m256 pvec = _mm256_set1_ps(p_invscale); + int k = 0; + for (; k + 7 < len; k += 8) + { + __m128i v8 = _mm_loadl_epi64((const __m128i*)(v + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + _mm256_storeu_ps(out + k, _mm256_fmadd_ps(pvec, vfp, _mm256_loadu_ps(out + k))); + } + for (; k + 3 < len; k += 4) + { + __m128 pvec128 = _mm_set1_ps(p_invscale); + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(v + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + _mm_storeu_ps(out + k, _mm_add_ps(_mm_loadu_ps(out + k), _mm_mul_ps(pvec128, vfp))); + } + for (; k < len; k++) + out[k] += p_invscale * v[k]; +} +#endif // __AVX2__ + +#if __AVX512F__ +static inline void pv_float_int8_fma_block_avx512(float* out, float p_invscale, const signed char* v, int len) +{ + __m512 pvec = _mm512_set1_ps(p_invscale); + int k = 0; + for (; k + 15 < len; k += 16) + { + __m128i v8 = _mm_loadu_si128((const __m128i*)(v + k)); + __m512i v32 = _mm512_cvtepi8_epi32(v8); + __m512 vfp = _mm512_cvtepi32_ps(v32); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec, vfp, _mm512_loadu_ps(out + k))); + } + for (; k + 7 < len; k += 8) + { + __m256 pvec256 = _mm256_set1_ps(p_invscale); + __m128i v8 = 
_mm_loadl_epi64((const __m128i*)(v + k)); + __m256i v32 = _mm256_cvtepi8_epi32(v8); + __m256 vfp = _mm256_cvtepi32_ps(v32); + _mm256_storeu_ps(out + k, _mm256_fmadd_ps(pvec256, vfp, _mm256_loadu_ps(out + k))); + } + for (; k < len; k++) + out[k] += p_invscale * v[k]; +} +#endif // __AVX512F__ + +static void pv_float_int8_fma_block_dispatch(float* out, float p_invscale, const signed char* v, int len) +{ +#if __AVX512F__ + pv_float_int8_fma_block_avx512(out, p_invscale, v, len); +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + pv_float_int8_fma_block_avx2(out, p_invscale, v, len); + return; + } +#endif +#if __AVX2__ + pv_float_int8_fma_block_avx2(out, p_invscale, v, len); +#elif __SSE2__ + pv_float_int8_fma_block_sse2(out, p_invscale, v, len); +#else + pv_float_int8_fma_block_scalar(out, p_invscale, v, len); +#endif +#endif +} + +#if __AVX2__ +inline void __attribute__((noinline)) pv_float_int8_gemm_tile_avx2(float* O, const float* P, + const signed char* V, const float* vscales, + int block_m, int block_n, int out_embed_dim) +{ + const int v_num_blocks = (out_embed_dim + 31) / 32; +int i = 0; +for (; i + 1 < block_m; i += 2) +{ + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m256 acc0_0 = _mm256_setzero_ps(); + __m256 acc0_1 = _mm256_setzero_ps(); + __m256 acc0_2 = _mm256_setzero_ps(); + __m256 acc0_3 = _mm256_setzero_ps(); + __m256 acc1_0 = _mm256_setzero_ps(); + __m256 acc1_1 = _mm256_setzero_ps(); + __m256 acc1_2 = _mm256_setzero_ps(); + __m256 acc1_3 = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m256 vscale8 = _mm256_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadl_epi64((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadl_epi64((const __m128i*)(vptr + 8)); + __m128i v8_2 = _mm_loadl_epi64((const __m128i*)(vptr + 16)); + __m128i v8_3 = _mm_loadl_epi64((const __m128i*)(vptr + 24)); + __m256 vval_0 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_0)), vscale8); + __m256 vval_1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_1)), vscale8); + __m256 vval_2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_2)), vscale8); + __m256 vval_3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_3)), vscale8); + __m256 pvec0 = _mm256_set1_ps(p0[j]); + __m256 pvec1 = _mm256_set1_ps(p1[j]); + acc0_0 = _mm256_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm256_fmadd_ps(pvec0, vval_1, acc0_1); + acc0_2 = _mm256_fmadd_ps(pvec0, vval_2, acc0_2); + acc0_3 = _mm256_fmadd_ps(pvec0, vval_3, acc0_3); + acc1_0 = _mm256_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm256_fmadd_ps(pvec1, vval_1, acc1_1); + acc1_2 = _mm256_fmadd_ps(pvec1, vval_2, acc1_2); + acc1_3 = _mm256_fmadd_ps(pvec1, vval_3, acc1_3); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + _mm256_storeu_ps(optr0 + 0, _mm256_add_ps(_mm256_loadu_ps(optr0 + 0), acc0_0)); + _mm256_storeu_ps(optr0 + 8, _mm256_add_ps(_mm256_loadu_ps(optr0 + 8), acc0_1)); + 
_mm256_storeu_ps(optr0 + 16, _mm256_add_ps(_mm256_loadu_ps(optr0 + 16), acc0_2)); + _mm256_storeu_ps(optr0 + 24, _mm256_add_ps(_mm256_loadu_ps(optr0 + 24), acc0_3)); + _mm256_storeu_ps(optr1 + 0, _mm256_add_ps(_mm256_loadu_ps(optr1 + 0), acc1_0)); + _mm256_storeu_ps(optr1 + 8, _mm256_add_ps(_mm256_loadu_ps(optr1 + 8), acc1_1)); + _mm256_storeu_ps(optr1 + 16, _mm256_add_ps(_mm256_loadu_ps(optr1 + 16), acc1_2)); + _mm256_storeu_ps(optr1 + 24, _mm256_add_ps(_mm256_loadu_ps(optr1 + 24), acc1_3)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + } + } + } + } +} +for (; i < block_m; i++) +{ + const float* p_row = P + i * block_n; + for (int j = 0; j < block_n; j++) + { + float p = p_row[j]; + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k_start + k] * vscale; + } + } + } +} +} +#endif + +static void pv_float_int8_gemm_tile(float* O, const float* P, + const signed char* V, const float* vscales, + int block_m, int block_n, int out_embed_dim) +{ + const int v_num_blocks = (out_embed_dim + 31) / 32; +#if __AVX512F__ + int i = 0; + for (; i + 3 < block_m; i += 4) + { + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + const float* p2 = P + (i + 2) * block_n; + const float* p3 = P + (i + 3) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m512 acc0_0 = _mm512_setzero_ps(); + __m512 acc0_1 = _mm512_setzero_ps(); + __m512 acc1_0 = _mm512_setzero_ps(); + __m512 acc1_1 = _mm512_setzero_ps(); + __m512 acc2_0 = _mm512_setzero_ps(); + __m512 acc2_1 = _mm512_setzero_ps(); + __m512 acc3_0 = _mm512_setzero_ps(); + __m512 acc3_1 = _mm512_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m512 vscale16 = _mm512_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadu_si128((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(vptr + 16)); + __m512 vval_0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), vscale16); + __m512 vval_1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), vscale16); + __m512 pvec0 = _mm512_set1_ps(p0[j]); + __m512 pvec1 = _mm512_set1_ps(p1[j]); + __m512 pvec2 = _mm512_set1_ps(p2[j]); + __m512 pvec3 = _mm512_set1_ps(p3[j]); + acc0_0 = _mm512_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm512_fmadd_ps(pvec0, vval_1, 
acc0_1); + acc1_0 = _mm512_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm512_fmadd_ps(pvec1, vval_1, acc1_1); + acc2_0 = _mm512_fmadd_ps(pvec2, vval_0, acc2_0); + acc2_1 = _mm512_fmadd_ps(pvec2, vval_1, acc2_1); + acc3_0 = _mm512_fmadd_ps(pvec3, vval_0, acc3_0); + acc3_1 = _mm512_fmadd_ps(pvec3, vval_1, acc3_1); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + float* optr2 = O + (i + 2) * out_embed_dim + k_start; + float* optr3 = O + (i + 3) * out_embed_dim + k_start; + _mm512_storeu_ps(optr0 + 0, _mm512_add_ps(_mm512_loadu_ps(optr0 + 0), acc0_0)); + _mm512_storeu_ps(optr0 + 16, _mm512_add_ps(_mm512_loadu_ps(optr0 + 16), acc0_1)); + _mm512_storeu_ps(optr1 + 0, _mm512_add_ps(_mm512_loadu_ps(optr1 + 0), acc1_0)); + _mm512_storeu_ps(optr1 + 16, _mm512_add_ps(_mm512_loadu_ps(optr1 + 16), acc1_1)); + _mm512_storeu_ps(optr2 + 0, _mm512_add_ps(_mm512_loadu_ps(optr2 + 0), acc2_0)); + _mm512_storeu_ps(optr2 + 16, _mm512_add_ps(_mm512_loadu_ps(optr2 + 16), acc2_1)); + _mm512_storeu_ps(optr3 + 0, _mm512_add_ps(_mm512_loadu_ps(optr3 + 0), acc3_0)); + _mm512_storeu_ps(optr3 + 16, _mm512_add_ps(_mm512_loadu_ps(optr3 + 16), acc3_1)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + O[(i + 2) * out_embed_dim + k_start + k] += p2[j] * vval; + O[(i + 3) * out_embed_dim + k_start + k] += p3[j] * vval; + } + } + } + } + } + for (; i < block_m; i++) + { + const float* p_row = P + i * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m512 vscale16 = _mm512_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadu_si128((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(vptr + 16)); + __m512 vval_0 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), vscale16); + __m512 vval_1 = _mm512_mul_ps(_mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), vscale16); + __m512 pvec = _mm512_set1_ps(p_row[j]); + acc0 = _mm512_fmadd_ps(pvec, vval_0, acc0); + acc1 = _mm512_fmadd_ps(pvec, vval_1, acc1); + } + float* optr = O + i * out_embed_dim + k_start; + _mm512_storeu_ps(optr + 0, _mm512_add_ps(_mm512_loadu_ps(optr + 0), acc0)); + _mm512_storeu_ps(optr + 16, _mm512_add_ps(_mm512_loadu_ps(optr + 16), acc1)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float p = p_row[j]; + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k] * vscale; + } + } + } + } + } +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + if (ncnn::cpu_support_x86_avx2()) + { + pv_float_int8_gemm_tile_avx2(O, P, V, vscales, block_m, block_n, out_embed_dim); + return; + } +#endif +#if __AVX2__ + int i = 0; + for (; i + 1 < block_m; i += 2) + { + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) + { + __m256 acc0_0 = _mm256_setzero_ps(); + __m256 acc0_1 = _mm256_setzero_ps(); + __m256 acc0_2 = _mm256_setzero_ps(); + __m256 acc0_3 = _mm256_setzero_ps(); + __m256 acc1_0 = _mm256_setzero_ps(); + __m256 acc1_1 = _mm256_setzero_ps(); + __m256 acc1_2 = _mm256_setzero_ps(); + __m256 acc1_3 = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m256 vscale8 = _mm256_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadl_epi64((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadl_epi64((const __m128i*)(vptr + 8)); + __m128i v8_2 = _mm_loadl_epi64((const __m128i*)(vptr + 16)); + __m128i v8_3 = _mm_loadl_epi64((const __m128i*)(vptr + 24)); + __m256 vval_0 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_0)), vscale8); + __m256 vval_1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_1)), vscale8); + __m256 vval_2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_2)), vscale8); + __m256 vval_3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_3)), vscale8); + __m256 pvec0 = _mm256_set1_ps(p0[j]); + __m256 pvec1 = _mm256_set1_ps(p1[j]); + acc0_0 = _mm256_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm256_fmadd_ps(pvec0, vval_1, acc0_1); + acc0_2 = _mm256_fmadd_ps(pvec0, vval_2, acc0_2); + acc0_3 = _mm256_fmadd_ps(pvec0, vval_3, acc0_3); + acc1_0 = _mm256_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm256_fmadd_ps(pvec1, vval_1, acc1_1); + acc1_2 = _mm256_fmadd_ps(pvec1, vval_2, acc1_2); + acc1_3 = _mm256_fmadd_ps(pvec1, vval_3, acc1_3); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + _mm256_storeu_ps(optr0 + 0, _mm256_add_ps(_mm256_loadu_ps(optr0 + 0), acc0_0)); + _mm256_storeu_ps(optr0 + 8, _mm256_add_ps(_mm256_loadu_ps(optr0 + 8), acc0_1)); + 
_mm256_storeu_ps(optr0 + 16, _mm256_add_ps(_mm256_loadu_ps(optr0 + 16), acc0_2)); + _mm256_storeu_ps(optr0 + 24, _mm256_add_ps(_mm256_loadu_ps(optr0 + 24), acc0_3)); + _mm256_storeu_ps(optr1 + 0, _mm256_add_ps(_mm256_loadu_ps(optr1 + 0), acc1_0)); + _mm256_storeu_ps(optr1 + 8, _mm256_add_ps(_mm256_loadu_ps(optr1 + 8), acc1_1)); + _mm256_storeu_ps(optr1 + 16, _mm256_add_ps(_mm256_loadu_ps(optr1 + 16), acc1_2)); + _mm256_storeu_ps(optr1 + 24, _mm256_add_ps(_mm256_loadu_ps(optr1 + 24), acc1_3)); + } + else + { + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + } + } + } + } + } + for (; i < block_m; i++) + { + const float* p_row = P + i * block_n; + for (int j = 0; j < block_n; j++) + { + float p = p_row[j]; + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k_start + k] * vscale; + } + } + } + } +#elif __SSE2__ + for (int j = 0; j < block_n; j++) + { + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + int k = k_start; + for (; k + 3 < k_end; k += 4) + { + __m128i v8 = _mm_cvtsi32_si128(*(const int*)(vptr + k)); + __m128i v32 = _mm_cvtepi8_epi32(v8); + __m128 vfp = _mm_cvtepi32_ps(v32); + __m128 vval = _mm_mul_ps(vfp, _mm_set1_ps(vscale)); + for (int ii = 0; ii < block_m; ii++) + { + float p = P[ii * block_n + j]; + __m128 pvec = _mm_set1_ps(p); + float* optr = O + ii * out_embed_dim + k; + _mm_storeu_ps(optr, _mm_add_ps(_mm_loadu_ps(optr), _mm_mul_ps(pvec, vval))); + } + } + for (; k < k_end; k++) + { + float vval = (float)vptr[k] * vscale; + for (int ii = 0; ii < block_m; ii++) + { + O[ii * out_embed_dim + k] += P[ii * block_n + j] * vval; + } + } + } + } +#else + for (int j = 0; j < block_n; j++) + { + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = k_start; k < k_end; k++) + { + float vval = (float)vptr[k] * vscale; + for (int ii = 0; ii < block_m; ii++) + { + O[ii * out_embed_dim + k] += P[ii * block_n + j] * vval; + } + } + } + } +#endif +#endif +} + +#if __SSE2__ && !__SSSE3__ +#undef _mm_cvtepi8_epi16 +#endif +#if __SSE2__ && !__SSE4_1__ +#undef _mm_cvtepi8_epi32 +#endif + + From c654a40d4029d2b64447c725ca19bb01ccb3484c Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 3 May 2026 21:41:12 +0800 Subject: [PATCH 27/53] support int8 --- src/layer/x86/sdpa_x86.cpp | 377 ++++++- src/layer/x86/sdpa_x86.h | 12 + src/layer/x86/sdpa_x86_avx2.cpp | 66 +- src/layer/x86/sdpa_x86_avx512vnni.cpp | 42 + src/layer/x86/sdpa_x86_avxvnni.cpp | 41 + src/layer/x86/sdpa_x86_avxvnniint8.cpp | 39 + src/layer/x86/sdpa_x86_int8.h | 1384 +++++++++++++++++++----- src/layer/x86/sdpa_x86_xop.cpp | 56 + 8 files changed, 1670 insertions(+), 347 deletions(-) create mode 100644 src/layer/x86/sdpa_x86_avx512vnni.cpp create mode 100644 src/layer/x86/sdpa_x86_avxvnni.cpp create mode 100644 src/layer/x86/sdpa_x86_avxvnniint8.cpp create mode 100644 src/layer/x86/sdpa_x86_xop.cpp diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index c13c7ffac909..6f27a1d8d5c3 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -32,6 +32,12 @@ SDPA_x86::SDPA_x86() #if NCNN_BF16 support_bf16_storage = false; #endif +#if NCNN_INT8 + cached_kv_seqlen = -1; + cached_num_group = 0; + cached_embed_dim = 0; + cached_out_embed_dim = 0; +#endif } int SDPA_x86::create_pipeline(const Option& /*_opt*/) @@ -3614,47 +3620,300 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return -100; #if NCNN_INT8 - if (int8_scale_term) + bool use_int8_path = int8_scale_term; + if (use_int8_path && src_seqlen == 1) + { + // Adaptive threshold: int8 decode is faster only 
when compute is large enough + // to amortize quantization overhead. For small problems, fall back to fp32. + if (embed_dim <= 128) + use_int8_path = (dst_seqlen >= 64); + else if (embed_dim <= 512) + use_int8_path = (dst_seqlen >= 512); + else + use_int8_path = (dst_seqlen >= 32); + } + + if (use_int8_path) { const int qk_num_blocks = (embed_dim + 31) / 32; const int v_num_blocks = (out_embed_dim + 31) / 32; Mat key_int8(embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); - Mat key_scales(qk_num_blocks, dst_seqlen, num_group, 4u, opt.blob_allocator); + Mat key_scales(1, dst_seqlen, num_group, 4u, opt.blob_allocator); Mat value_int8(out_embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); Mat value_scales(v_num_blocks, dst_seqlen, num_group, 4u, opt.blob_allocator); if (key_int8.empty() || key_scales.empty() || value_int8.empty() || value_scales.empty()) return -100; + bool use_kv_cache = kv_cache && past_seqlen > 0; + bool cache_valid = false; + if (use_kv_cache + && cached_kv_seqlen == past_seqlen + && cached_num_group == num_group + && cached_embed_dim == embed_dim + && cached_out_embed_dim == out_embed_dim + && !cached_key_int8.empty() + && !cached_key_scales.empty() + && !cached_value_int8.empty() + && !cached_value_scales.empty()) + { + cache_valid = true; + for (int g = 0; g < num_group; g++) + { + memcpy(key_int8.channel(g), cached_key_int8.channel(g), embed_dim * past_seqlen); + memcpy(key_scales.channel(g), cached_key_scales.channel(g), past_seqlen * sizeof(float)); + memcpy(value_int8.channel(g), cached_value_int8.channel(g), out_embed_dim * past_seqlen); + memcpy(value_scales.channel(g), cached_value_scales.channel(g), v_num_blocks * past_seqlen * sizeof(float)); + } + } + #pragma omp parallel for num_threads(opt.num_threads) for (int g = 0; g < num_group; g++) { const Mat key_head = key.channel(g); Mat key_int8_head = key_int8.channel(g); Mat key_scales_head = key_scales.channel(g); - for (int j = 0; j < dst_seqlen; j++) + int j_start = 
cache_valid ? past_seqlen : 0; + for (int j = j_start; j < dst_seqlen; j++) { - dynamic_quantize_blockwise_dispatch(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); + dynamic_quantize_rowwise(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); + key_scales_head.row(j)[0] = 1.f / key_scales_head.row(j)[0]; } - const Mat value_head = value.channel(g); - Mat value_int8_head = value_int8.channel(g); - Mat value_scales_head = value_scales.channel(g); - for (int j = 0; j < dst_seqlen; j++) + if (kv_cache) { - dynamic_quantize_blockwise_dispatch(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); + const Mat value_head = value.channel(g); + Mat value_int8_head = value_int8.channel(g); + Mat value_scales_head = value_scales.channel(g); + for (int j = j_start; j < dst_seqlen; j++) + { + dynamic_quantize_blockwise(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); + for (int vb = 0; vb < v_num_blocks; vb++) + { + value_scales_head.row(j)[vb] = 1.f / value_scales_head.row(j)[vb]; + } + } } } Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat s_vec(BLOCK_N, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); Mat q_int8_tile(embed_dim, BLOCK_M, opt.num_threads, 1u, opt.workspace_allocator); - Mat q_scales_tile(qk_num_blocks, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + Mat q_scales_tile(1, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - if (o_accum.empty() || s_vec.empty() || q_int8_tile.empty() || q_scales_tile.empty()) + if (o_accum.empty() || s_vec.empty() || p_vec.empty() || q_int8_tile.empty() || q_scales_tile.empty()) return -100; + if (src_seqlen == 1) + { + // Decode path with dedicated int8 GEMV kernels + // For GQA/MQA, group-parallel reduces KV 
cache contention + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_int8_head = key_int8.channel(g); + const Mat key_scales_head = key_scales.channel(g); + const Mat value_int8_head = value_int8.channel(g); + const Mat value_scales_head = value_scales.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); + float* q_scale = q_scales_tile.channel(get_omp_thread_num()); + dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); + q_scale[0] = 1.f / q_scale[0]; + + float* s = s_vec.channel(get_omp_thread_num()); + float* out = o_accum.channel(get_omp_thread_num()); + vec_zero_dispatch(out, out_embed_dim); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + + decode_qk_dot_int8(s, q_int8, + key_int8_head.row(0), + q_scale, + key_scales_head.row(0), + n_start, block_n, embed_dim, _scale); + + if (mask_ptr) + { + for (int j = 0; j < block_n; j++) + s[j] += mask_ptr[n_start + j]; + } + + float tile_m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); + + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_dispatch(out, scale_factor, out_embed_dim); + } + + float l_add = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; + + decode_pv_gemv_int8(out, s, + value_int8_head.row(0), + value_scales_head.row(0), + n_start, block_n, out_embed_dim); + + m = new_m; + } + + float* outptr = top_blob_head.row(0); + float inv_l = 1.f / l; + for (int k = 0; k < out_embed_dim; k++) + outptr[k] = out[k] * inv_l; + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); + + signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); + float* q_scale = q_scales_tile.channel(get_omp_thread_num()); + dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); + q_scale[0] = 1.f / q_scale[0]; + + float* s = 
s_vec.channel(get_omp_thread_num()); + float* out = o_accum.channel(get_omp_thread_num()); + vec_zero_dispatch(out, out_embed_dim); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_blob; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + + decode_qk_dot_int8(s, q_int8, + key_int8_head.row(0), + q_scale, + key_scales_head.row(0), + n_start, block_n, embed_dim, _scale); + + if (mask_ptr) + { + for (int j = 0; j < block_n; j++) + s[j] += mask_ptr[n_start + j]; + } + + float tile_m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); + + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale_dispatch(out, scale_factor, out_embed_dim); + } + + float l_add = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; + + decode_pv_gemv_int8(out, s, + value_int8_head.row(0), + value_scales_head.row(0), + n_start, block_n, out_embed_dim); + + m = new_m; + } + + float* outptr = top_blob_head.row(0); + float inv_l = 1.f / l; + for (int k = 0; k < out_embed_dim; k++) + outptr[k] = out[k] * inv_l; + } + } + + if (kv_cache && dst_seqlen > 0) + { + cached_key_int8 = key_int8; + cached_key_scales = key_scales; + cached_value_int8 = value_int8; + cached_value_scales = value_scales; + cached_kv_seqlen = dst_seqlen; + cached_num_group = num_group; + cached_embed_dim = embed_dim; + cached_out_embed_dim = out_embed_dim; + } + + return 0; + } + // src_seqlen != 1: continue with the tiled int8 path below + + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_heads; q++) { @@ -3681,6 
+3940,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat o_accum_head = o_accum.channel(get_omp_thread_num()); float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + float* p_vec_ptr = p_vec.row(get_omp_thread_num()); Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); @@ -3691,7 +3951,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int i = 0; i < block_m; i++) { - dynamic_quantize_blockwise_dispatch(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + dynamic_quantize_rowwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; } for (int i = 0; i < block_m; i++) @@ -3716,44 +3977,43 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; - for (int i = 0; i < block_m; i++) + if (block_m == 1) { - const signed char* qptr = q_int8_tile_head.row(i); - const float* qscales = q_scales_tile_head.row(i); - - for (int j = 0; j < block_n; j++) - { - const signed char* kptr = key_int8_head.row(n_start + j); - const float* kscales = key_scales_head.row(n_start + j); + qk_int8_gemm_row(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0)[0], + key_scales_head.row(n_start), + block_n, embed_dim, _scale); + } + else + { + qk_int8_gemm_tiled(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } - float sum = 0.f; - for (int b = 0; b < qk_num_blocks; b++) - { - int k_start = b * 32; - int k_end = k_start + 32 < embed_dim ? 
k_start + 32 : embed_dim; - int block_sum = 0; - for (int k = k_start; k < k_end; k++) - { - block_sum += qptr[k] * kptr[k]; - } - sum += (float)block_sum / (qscales[b] * kscales[b]); - } - s_vec_ptr[j] = sum * _scale; - } + for (int i = 0; i < block_m; i++) + { + float* s_row = (block_m == 1) ? s_vec_ptr : (s_vec_ptr + i * block_n); + float* p_row = (block_m == 1) ? p_vec_ptr : (p_vec_ptr + i * block_n); if (attn_mask) { const float* mptr = mask_head.row(m_start + i) + n_start; for (int j = 0; j < block_n; j++) { - s_vec_ptr[j] += mptr[j]; + s_row[j] += mptr[j]; } } float m_new = m_vec[i]; for (int j = 0; j < block_n; j++) { - m_new = std::max(m_new, s_vec_ptr[j]); + m_new = std::max(m_new, s_row[j]); } float scale_factor = expf(m_vec[i] - m_new); @@ -3767,26 +4027,27 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int j = 0; j < block_n; j++) { - float p = expf(s_vec_ptr[j] - m_new); - l_new += p; - - const signed char* vptr = value_int8_head.row(n_start + j); - const float* vscales = value_scales_head.row(n_start + j); - for (int vb = 0; vb < v_num_blocks; vb++) - { - float inv_scale = 1.f / vscales[vb]; - int k_start = vb * 32; - int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; - for (int k = k_start; k < k_end; k++) - { - optr[k] += p * (float)vptr[k] * inv_scale; - } - } + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; } m_vec[i] = m_new; l_vec[i] = l_new; } + + if (kv_cache) + { + pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, + value_int8_head.row(n_start), + value_scales_head.row(n_start), + block_m, block_n, out_embed_dim); + } + else + { + pv_gemm_dispatch(o_accum_head.row(0), p_vec_ptr, + value.channel(q / num_heads_per_group).row(n_start), + block_m, block_n, out_embed_dim); + } } for (int i = 0; i < block_m; i++) @@ -3808,6 +4069,18 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to top_blobs[2] = value; } + if (kv_cache && dst_seqlen > 0) + { + cached_key_int8 = key_int8; + cached_key_scales = key_scales; + cached_value_int8 = value_int8; + cached_value_scales = value_scales; + cached_kv_seqlen = dst_seqlen; + cached_num_group = num_group; + cached_embed_dim = embed_dim; + cached_out_embed_dim = out_embed_dim; + } + return 0; } #endif // NCNN_INT8 diff --git a/src/layer/x86/sdpa_x86.h b/src/layer/x86/sdpa_x86.h index 11f329dc0175..08ac3f3966bf 100644 --- a/src/layer/x86/sdpa_x86.h +++ b/src/layer/x86/sdpa_x86.h @@ -17,6 +17,18 @@ class SDPA_x86 : public SDPA virtual int destroy_pipeline(const Option& opt); virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; + +private: +#if NCNN_INT8 + mutable Mat cached_key_int8; + mutable Mat cached_key_scales; + mutable Mat cached_value_int8; + mutable Mat cached_value_scales; + mutable int cached_kv_seqlen; + mutable int cached_num_group; + mutable int cached_embed_dim; + mutable int cached_out_embed_dim; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avx2.cpp b/src/layer/x86/sdpa_x86_avx2.cpp index dd1277df7fba..7a80ae4ed765 100644 --- a/src/layer/x86/sdpa_x86_avx2.cpp +++ b/src/layer/x86/sdpa_x86_avx2.cpp @@ -18,23 +18,57 @@ namespace ncnn { 
#include "sdpa_x86_int8.h" #if __AVX2__ -// force emit inline function symbols for non-avx2 runtime dispatch -static void __attribute__((used)) sdpa_x86_int8_avx2_dummy() -{ - // call through volatile function pointer to force instantiation - void (* volatile f1)(const float*, signed char*, float*, int) = dynamic_quantize_blockwise_avx2; - void (* volatile f2)(const float*, signed char*, float*, int) = dynamic_quantize_rowwise_avx2; - int (* volatile f3)(const signed char*, const signed char*, int) = qk_int8_dot_block_avx2; - void (* volatile f4)(float*, const signed char*, const signed char*, const float*, const float*, int, int, int, float) = decode_qk_dot_int8_avx2; - void (* volatile f5)(float*, const signed char*, const signed char*, float, const float*, int, int, float) = qk_int8_gemm_row_avx2; - void (* volatile f6)(float*, const signed char*, const signed char*, const float*, const float*, int, int, int, float) = qk_int8_gemm_tiled_avx2; - void (* volatile f7)(float*, const float*, const signed char*, const float*, int, int, int) = decode_pv_gemv_int8_avx2; - void (* volatile f8)(float*, const float*, const signed char*, const float*, int, int) = pv_float_int8_gemm_row_avx2; - void (* volatile f9)(float*, float, const signed char*, int) = pv_float_int8_fma_block_avx2; - void (* volatile f10)(float*, const float*, const signed char*, const float*, int, int, int) = pv_float_int8_gemm_tile_avx2; - (void)f1; (void)f2; (void)f3; (void)f4; (void)f5; - (void)f6; (void)f7; (void)f8; (void)f9; (void)f10; + +void dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width) +{ + dynamic_quantize_blockwise_avx2_kernel(src, dst, scales, width); +} + +void dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width) +{ + dynamic_quantize_rowwise_avx2_kernel(src, dst, scale, width); +} + +int qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len) +{ + return 
qk_int8_dot_block_avx2_kernel(a, b, len); +} + +void decode_qk_dot_int8_avx2(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avx2_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avx2(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avx2(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +void decode_pv_gemv_int8_avx2(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) +{ + decode_pv_gemv_int8_avx2_kernel(out, s, V, vscales, n_start, block_n, out_d); +} + +void pv_float_int8_gemm_row_avx2(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) +{ + pv_float_int8_gemm_row_avx2_kernel(out, p_row, V, vscales, n, out_d); } + +void pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed char* v, int len) +{ + pv_float_int8_fma_block_avx2_kernel(out, p_invscale, v, len); +} + +void pv_float_int8_gemm_tile_avx2(float* O, const float* P, const signed char* V, const float* vscales, int block_m, int block_n, int out_embed_dim) +{ + pv_float_int8_gemm_tile_avx2_kernel(O, P, V, vscales, block_m, block_n, out_embed_dim); +} + #endif // __AVX2__ } // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avx512vnni.cpp b/src/layer/x86/sdpa_x86_avx512vnni.cpp new file mode 100644 index 000000000000..2719d668ed73 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avx512vnni.cpp @@ -0,0 +1,42 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include 
"cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVX512F__ && __AVX512VNNI__ + +void decode_qk_dot_int8_avx512vnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avx512_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avx512vnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + qk_int8_gemm_row_avx512vnni_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avx512vnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + qk_int8_gemm_tiled_avx512vnni_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +#endif // __AVX512F__ && __AVX512VNNI__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avxvnni.cpp b/src/layer/x86/sdpa_x86_avxvnni.cpp new file mode 100644 index 000000000000..296cbc89a7e0 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avxvnni.cpp @@ -0,0 +1,41 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVXVNNI__ + +void decode_qk_dot_int8_avxvnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avxvnni_kernel(s, q, K, qscales, 
kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avxvnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + // fallback to avx2 kernel for row-wise gemm + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avxvnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + // fallback to avx2 kernel for tiled gemm + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +} + +#endif // __AVXVNNI__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avxvnniint8.cpp b/src/layer/x86/sdpa_x86_avxvnniint8.cpp new file mode 100644 index 000000000000..f858dceb08c5 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avxvnniint8.cpp @@ -0,0 +1,39 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __AVXVNNIINT8__ + +void decode_qk_dot_int8_avxvnniint8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_avxvnni_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_avxvnniint8(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +} + +void qk_int8_gemm_tiled_avxvnniint8(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, 
kscales, m, n, d, scale); +} + +#endif // __AVXVNNIINT8__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_int8.h b/src/layer/x86/sdpa_x86_int8.h index 1ba9d19311f9..9969e901c824 100644 --- a/src/layer/x86/sdpa_x86_int8.h +++ b/src/layer/x86/sdpa_x86_int8.h @@ -6,7 +6,25 @@ static inline signed char float2int8(float v) return (signed char)int32; } -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX__ && !__AVX512VNNI__ +void decode_qk_dot_int8_avx512vnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avx512vnni(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avx512vnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void decode_qk_dot_int8_avxvnniint8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avxvnniint8(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avxvnniint8(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void decode_qk_dot_int8_avxvnni(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_avxvnni(float* s_row, const signed char* q_row, const signed char* K, 
float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_avxvnni(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ void dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width); void dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width); int qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len); @@ -19,7 +37,13 @@ void pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed cha void pv_float_int8_gemm_tile_avx2(float* O, const float* P, const signed char* V, const float* vscales, int block_m, int block_n, int out_embed_dim); #endif -static void dynamic_quantize_blockwise_scalar(const float* src, signed char* dst, float* scales, int width) +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ +void decode_qk_dot_int8_xop(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale); +void qk_int8_gemm_row_xop(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale); +void qk_int8_gemm_tiled_xop(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale); +#endif + +static void dynamic_quantize_blockwise_scalar_kernel(const float* src, signed char* dst, float* scales, int width) { const int block_size = 32; int num_blocks = (width + block_size - 1) / block_size; @@ -42,7 +66,7 @@ static void dynamic_quantize_blockwise_scalar(const float* src, signed char* dst } #if __SSE2__ -static inline void dynamic_quantize_blockwise_sse2(const float* src, signed char* dst, float* scales, int width) +static inline 
void dynamic_quantize_blockwise_sse2_kernel(const float* src, signed char* dst, float* scales, int width) { const int block_size = 32; int num_blocks = (width + block_size - 1) / block_size; @@ -98,7 +122,7 @@ static inline void dynamic_quantize_blockwise_sse2(const float* src, signed char #endif // __SSE2__ #if __AVX2__ -inline void __attribute__((noinline)) dynamic_quantize_blockwise_avx2(const float* src, signed char* dst, float* scales, int width) +inline void dynamic_quantize_blockwise_avx2_kernel(const float* src, signed char* dst, float* scales, int width) { const int block_size = 32; int num_blocks = (width + block_size - 1) / block_size; @@ -172,7 +196,7 @@ inline void __attribute__((noinline)) dynamic_quantize_blockwise_avx2(const floa #endif // __AVX2__ #if __AVX512F__ -static inline void dynamic_quantize_blockwise_avx512(const float* src, signed char* dst, float* scales, int width) +static inline void dynamic_quantize_blockwise_avx512_kernel(const float* src, signed char* dst, float* scales, int width) { const int block_size = 32; int num_blocks = (width + block_size - 1) / block_size; @@ -220,10 +244,10 @@ static inline void dynamic_quantize_blockwise_avx512(const float* src, signed ch } #endif // __AVX512F__ -static void dynamic_quantize_blockwise_dispatch(const float* src, signed char* dst, float* scales, int width) +static void dynamic_quantize_blockwise(const float* src, signed char* dst, float* scales, int width) { #if __AVX512F__ - dynamic_quantize_blockwise_avx512(src, dst, scales, width); + dynamic_quantize_blockwise_avx512_kernel(src, dst, scales, width); #else #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ if (ncnn::cpu_support_x86_avx2()) @@ -235,14 +259,14 @@ static void dynamic_quantize_blockwise_dispatch(const float* src, signed char* d #if __AVX2__ dynamic_quantize_blockwise_avx2(src, dst, scales, width); #elif __SSE2__ - dynamic_quantize_blockwise_sse2(src, dst, scales, width); + dynamic_quantize_blockwise_sse2_kernel(src, 
dst, scales, width); #else - dynamic_quantize_blockwise_scalar(src, dst, scales, width); + dynamic_quantize_blockwise_scalar_kernel(src, dst, scales, width); #endif #endif } -static inline void dynamic_quantize_rowwise_scalar(const float* src, signed char* dst, float* scale, int width) +static inline void dynamic_quantize_rowwise_scalar_kernel(const float* src, signed char* dst, float* scale, int width) { float absmax = 0.f; for (int i = 0; i < width; i++) @@ -258,7 +282,7 @@ static inline void dynamic_quantize_rowwise_scalar(const float* src, signed char } #if __SSE2__ -static inline void dynamic_quantize_rowwise_sse2(const float* src, signed char* dst, float* scale, int width) +static inline void dynamic_quantize_rowwise_sse2_kernel(const float* src, signed char* dst, float* scale, int width) { __m128 sign_mask = _mm_set1_ps(-0.f); __m128 vmax = _mm_setzero_ps(); @@ -295,7 +319,7 @@ static inline void dynamic_quantize_rowwise_sse2(const float* src, signed char* #endif #if __AVX2__ -inline void __attribute__((noinline)) dynamic_quantize_rowwise_avx2(const float* src, signed char* dst, float* scale, int width) +inline void dynamic_quantize_rowwise_avx2_kernel(const float* src, signed char* dst, float* scale, int width) { __m256 sign_mask = _mm256_set1_ps(-0.f); __m256 vmax = _mm256_setzero_ps(); @@ -331,7 +355,7 @@ inline void __attribute__((noinline)) dynamic_quantize_rowwise_avx2(const float* #endif #if __AVX512F__ -static inline void dynamic_quantize_rowwise_avx512(const float* src, signed char* dst, float* scale, int width) +static inline void dynamic_quantize_rowwise_avx512_kernel(const float* src, signed char* dst, float* scale, int width) { __m512 sign_mask = _mm512_set1_ps(-0.f); __m512 vmax = _mm512_setzero_ps(); @@ -370,10 +394,10 @@ static inline void dynamic_quantize_rowwise_avx512(const float* src, signed char } #endif -static void dynamic_quantize_rowwise_dispatch(const float* src, signed char* dst, float* scale, int width) +static void 
dynamic_quantize_rowwise(const float* src, signed char* dst, float* scale, int width) { #if __AVX512F__ - dynamic_quantize_rowwise_avx512(src, dst, scale, width); + dynamic_quantize_rowwise_avx512_kernel(src, dst, scale, width); #else #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ if (ncnn::cpu_support_x86_avx2()) @@ -385,9 +409,9 @@ static void dynamic_quantize_rowwise_dispatch(const float* src, signed char* dst #if __AVX2__ dynamic_quantize_rowwise_avx2(src, dst, scale, width); #elif __SSE2__ - dynamic_quantize_rowwise_sse2(src, dst, scale, width); + dynamic_quantize_rowwise_sse2_kernel(src, dst, scale, width); #else - dynamic_quantize_rowwise_scalar(src, dst, scale, width); + dynamic_quantize_rowwise_scalar_kernel(src, dst, scale, width); #endif #endif } @@ -422,7 +446,7 @@ static inline void reciprocal_scales(float* scales, int num_blocks) // =================== Int8 SIMD Kernels =================== -static inline int qk_int8_dot_block_scalar(const signed char* a, const signed char* b, int len) +static inline int qk_int8_dot_block_scalar_kernel(const signed char* a, const signed char* b, int len) { int sum = 0; for (int i = 0; i < len; i++) @@ -451,7 +475,7 @@ static inline __m128i _mm_cvtepi8_epi32_sse2(__m128i x) #endif #if __SSE2__ -static inline int qk_int8_dot_block_sse2(const signed char* a, const signed char* b, int len) +static inline int qk_int8_dot_block_sse2_kernel(const signed char* a, const signed char* b, int len) { __m128i sum = _mm_setzero_si128(); int i = 0; @@ -473,8 +497,31 @@ static inline int qk_int8_dot_block_sse2(const signed char* a, const signed char } #endif // __SSE2__ +#if __XOP__ +static inline int qk_int8_dot_block_xop_kernel(const signed char* a, const signed char* b, int len) +{ + __m128i sum = _mm_setzero_si128(); + int i = 0; + for (; i + 15 < len; i += 16) + { + __m128i va = _mm_loadu_si128((const __m128i*)(a + i)); + __m128i vb = _mm_loadu_si128((const __m128i*)(b + i)); + __m128i va_lo = _mm_cvtepi8_epi16(va); + 
__m128i va_hi = _mm_cvtepi8_epi16(_mm_srli_si128(va, 8)); + __m128i vb_lo = _mm_cvtepi8_epi16(vb); + __m128i vb_hi = _mm_cvtepi8_epi16(_mm_srli_si128(vb, 8)); + sum = _mm_maccd_epi16(va_lo, vb_lo, sum); + sum = _mm_maccd_epi16(va_hi, vb_hi, sum); + } + int sum_tail = 0; + for (; i < len; i++) + sum_tail += a[i] * b[i]; + return _mm_reduce_add_epi32(sum) + sum_tail; +} +#endif // __XOP__ + #if __AVX2__ -inline int __attribute__((noinline)) qk_int8_dot_block_avx2(const signed char* a, const signed char* b, int len) +inline int qk_int8_dot_block_avx2_kernel(const signed char* a, const signed char* b, int len) { __m256i sum = _mm256_setzero_si256(); int i = 0; @@ -486,8 +533,8 @@ inline int __attribute__((noinline)) qk_int8_dot_block_avx2(const signed char* a __m256i va_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(va, 1)); __m256i vb_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb)); __m256i vb_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1)); - sum = _mm256_add_epi32(sum, _mm256_madd_epi16(va_lo, vb_lo)); - sum = _mm256_add_epi32(sum, _mm256_madd_epi16(va_hi, vb_hi)); + sum = _mm256_comp_dpwssd_epi32(sum, va_lo, vb_lo); + sum = _mm256_comp_dpwssd_epi32(sum, va_hi, vb_hi); } int sum_tail = 0; for (; i < len; i++) @@ -498,7 +545,7 @@ inline int __attribute__((noinline)) qk_int8_dot_block_avx2(const signed char* a #endif // __AVX2__ #if __AVX512F__ -static inline int qk_int8_dot_block_avx512(const signed char* a, const signed char* b, int len) +static inline int qk_int8_dot_block_avx512_kernel(const signed char* a, const signed char* b, int len) { __m512i sum = _mm512_setzero_si512(); int i = 0; @@ -508,7 +555,7 @@ static inline int qk_int8_dot_block_avx512(const signed char* a, const signed ch __m256i vb256 = _mm256_loadu_si256((const __m256i*)(b + i)); __m512i va = _mm512_cvtepi8_epi16(va256); __m512i vb = _mm512_cvtepi8_epi16(vb256); - sum = _mm512_add_epi32(sum, _mm512_madd_epi16(va, vb)); + sum = _mm512_comp_dpwssd_epi32(sum, va, vb); } int 
sum_tail = 0; for (; i < len; i++) @@ -522,7 +569,7 @@ static inline int qk_int8_dot_block_avx512(const signed char* a, const signed ch static inline int qk_int8_dot_block(const signed char* a, const signed char* b, int len) { #if __AVX512F__ - return qk_int8_dot_block_avx512(a, b, len); + return qk_int8_dot_block_avx512_kernel(a, b, len); #else #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ if (ncnn::cpu_support_x86_avx2()) @@ -533,16 +580,16 @@ static inline int qk_int8_dot_block(const signed char* a, const signed char* b, #if __AVX2__ return qk_int8_dot_block_avx2(a, b, len); #elif __SSE2__ - return qk_int8_dot_block_sse2(a, b, len); + return qk_int8_dot_block_sse2_kernel(a, b, len); #else - return qk_int8_dot_block_scalar(a, b, len); + return qk_int8_dot_block_scalar_kernel(a, b, len); #endif #endif } // ------------------- Decode QK Dot Int8 ------------------- -static inline void decode_qk_dot_int8_scalar(float* s, const signed char* q, +static inline void decode_qk_dot_int8_scalar_kernel(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) { @@ -550,22 +597,22 @@ static inline void decode_qk_dot_int8_scalar(float* s, const signed char* q, for (int j = 0; j < block_n; j++) { const signed char* kptr = K + (n_start + j) * d; - const float* ks = kscales + (n_start + j) * num_blocks; + const float* ks = kscales + (n_start + j); float sum = 0.f; for (int b = 0; b < num_blocks; b++) { int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum = qk_int8_dot_block_scalar(q + off, kptr + off, len); - sum += (float)block_sum / (qscales[b] * ks[b]); + int block_sum = qk_int8_dot_block_scalar_kernel(q + off, kptr + off, len); + sum += (float)block_sum / (qscales[0] * ks[0]); } s[j] = sum * scale; } } #if __SSE2__ -static inline void decode_qk_dot_int8_sse2(float* s, const signed char* q, +static inline void 
decode_qk_dot_int8_sse2_kernel(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) { @@ -575,26 +622,26 @@ static inline void decode_qk_dot_int8_sse2(float* s, const signed char* q, { const signed char* k0 = K + (n_start + j + 0) * d; const signed char* k1 = K + (n_start + j + 1) * d; - const float* ks0 = kscales + (n_start + j + 0) * num_blocks; - const float* ks1 = kscales + (n_start + j + 1) * num_blocks; + const float* ks0 = kscales + (n_start + j + 0); + const float* ks1 = kscales + (n_start + j + 1); float sum0 = 0.f, sum1 = 0.f; for (int b = 0; b < num_blocks; b++) { - float descale = qscales[b] * ks0[b]; + float descale = qscales[0] * ks0[0]; int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum = qk_int8_dot_block_sse2(q + off, k0 + off, len); + int block_sum = qk_int8_dot_block_sse2_kernel(q + off, k0 + off, len); sum0 += (float)block_sum * descale; } for (int b = 0; b < num_blocks; b++) { - float descale = qscales[b] * ks1[b]; + float descale = qscales[0] * ks1[0]; int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum = qk_int8_dot_block_sse2(q + off, k1 + off, len); + int block_sum = qk_int8_dot_block_sse2_kernel(q + off, k1 + off, len); sum1 += (float)block_sum * descale; } @@ -604,15 +651,15 @@ static inline void decode_qk_dot_int8_sse2(float* s, const signed char* q, for (; j < block_n; j++) { const signed char* kptr = K + (n_start + j) * d; - const float* ks = kscales + (n_start + j) * num_blocks; + const float* ks = kscales + (n_start + j); float sum = 0.f; for (int b = 0; b < num_blocks; b++) { int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum = qk_int8_dot_block_scalar(q + off, kptr + off, len); - float descale = qscales[b] * ks[b]; + int block_sum = qk_int8_dot_block_scalar_kernel(q + off, kptr + off, len); + float descale = qscales[0] * 
ks[0]; sum += (float)block_sum * descale; } s[j] = sum * scale; @@ -620,8 +667,8 @@ static inline void decode_qk_dot_int8_sse2(float* s, const signed char* q, } #endif // __SSE2__ -#if __AVX2__ -inline void __attribute__((noinline)) decode_qk_dot_int8_avx2(float* s, const signed char* q, +#if __XOP__ +static inline void decode_qk_dot_int8_xop_kernel(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) { @@ -631,8 +678,8 @@ inline void __attribute__((noinline)) decode_qk_dot_int8_avx2(float* s, const si { const signed char* k0 = K + (n_start + j + 0) * d; const signed char* k1 = K + (n_start + j + 1) * d; - const float* ks0 = kscales + (n_start + j + 0) * num_blocks; - const float* ks1 = kscales + (n_start + j + 1) * num_blocks; + const float* ks0 = kscales + (n_start + j + 0); + const float* ks1 = kscales + (n_start + j + 1); float sum0 = 0.f, sum1 = 0.f; for (int b = 0; b < num_blocks; b++) @@ -640,10 +687,10 @@ inline void __attribute__((noinline)) decode_qk_dot_int8_avx2(float* s, const si int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum0 = qk_int8_dot_block_avx2(q + off, k0 + off, len); - int block_sum1 = qk_int8_dot_block_avx2(q + off, k1 + off, len); - sum0 += (float)block_sum0 / (qscales[b] * ks0[b]); - sum1 += (float)block_sum1 / (qscales[b] * ks1[b]); + int block_sum0 = qk_int8_dot_block_xop_kernel(q + off, k0 + off, len); + int block_sum1 = qk_int8_dot_block_xop_kernel(q + off, k1 + off, len); + sum0 += (float)block_sum0 * qscales[0] * ks0[0]; + sum1 += (float)block_sum1 * qscales[0] * ks1[0]; } s[j + 0] = sum0 * scale; @@ -652,89 +699,333 @@ inline void __attribute__((noinline)) decode_qk_dot_int8_avx2(float* s, const si for (; j < block_n; j++) { const signed char* kptr = K + (n_start + j) * d; - const float* ks = kscales + (n_start + j) * num_blocks; + const float* ks = kscales + (n_start + j); float sum = 0.f; for (int b 
= 0; b < num_blocks; b++) { int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum = qk_int8_dot_block_avx2(q + off, kptr + off, len); - float descale = qscales[b] * ks[b]; - sum += (float)block_sum * descale; + int block_sum = qk_int8_dot_block_xop_kernel(q + off, kptr + off, len); + sum += (float)block_sum * qscales[0] * ks[0]; } s[j] = sum * scale; } } -#endif // __AVX2__ +#endif // __XOP__ -#if __AVX512F__ -static inline void decode_qk_dot_int8_avx512(float* s, const signed char* q, +#if __AVX2__ +inline void decode_qk_dot_int8_avx2_kernel(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) { const int num_blocks = (d + 31) / 32; int j = 0; - for (; j + 3 < block_n; j += 4) + for (; j + 1 < block_n; j += 2) { const signed char* k0 = K + (n_start + j + 0) * d; const signed char* k1 = K + (n_start + j + 1) * d; - const signed char* k2 = K + (n_start + j + 2) * d; - const signed char* k3 = K + (n_start + j + 3) * d; - const float* ks0 = kscales + (n_start + j + 0) * num_blocks; - const float* ks1 = kscales + (n_start + j + 1) * num_blocks; - const float* ks2 = kscales + (n_start + j + 2) * num_blocks; - const float* ks3 = kscales + (n_start + j + 3) * num_blocks; + const float* ks0 = kscales + (n_start + j + 0); + const float* ks1 = kscales + (n_start + j + 1); - float sum0 = 0.f, sum1 = 0.f, sum2 = 0.f, sum3 = 0.f; + float sum0 = 0.f, sum1 = 0.f; for (int b = 0; b < num_blocks; b++) { int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int bs; - bs = qk_int8_dot_block_avx512(q + off, k0 + off, len); - sum0 += (float)bs / (qscales[b] * ks0[b]); - bs = qk_int8_dot_block_avx512(q + off, k1 + off, len); - sum1 += (float)bs / (qscales[b] * ks1[b]); - bs = qk_int8_dot_block_avx512(q + off, k2 + off, len); - sum2 += (float)bs / (qscales[b] * ks2[b]); - bs = qk_int8_dot_block_avx512(q + off, k3 + off, len); - sum3 += 
(float)bs / (qscales[b] * ks3[b]); + int block_sum0 = qk_int8_dot_block_avx2_kernel(q + off, k0 + off, len); + int block_sum1 = qk_int8_dot_block_avx2_kernel(q + off, k1 + off, len); + sum0 += (float)block_sum0 / (qscales[0] * ks0[0]); + sum1 += (float)block_sum1 / (qscales[0] * ks1[0]); } s[j + 0] = sum0 * scale; s[j + 1] = sum1 * scale; - s[j + 2] = sum2 * scale; - s[j + 3] = sum3 * scale; } for (; j < block_n; j++) { const signed char* kptr = K + (n_start + j) * d; - const float* ks = kscales + (n_start + j) * num_blocks; + const float* ks = kscales + (n_start + j); float sum = 0.f; for (int b = 0; b < num_blocks; b++) { int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - int block_sum = qk_int8_dot_block_avx512(q + off, kptr + off, len); - float descale = qscales[b] * ks[b]; + int block_sum = qk_int8_dot_block_avx2_kernel(q + off, kptr + off, len); + float descale = qscales[0] * ks[0]; sum += (float)block_sum * descale; } s[j] = sum * scale; } } +#endif // __AVX2__ + +#if __AVX2__ +static inline int _mm256_reduce_add_epi32(__m256i v); +#endif + +#if __AVXVNNI__ +static inline void decode_qk_dot_int8_avxvnni_kernel(float* s, const signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const float qscale = qscales[0]; + const int num_blocks_32 = d / 32; + + // Precompute qsum for dpbusd compensation: sum((q+128)*k) = sum(q*k) + 128*sum(k) + int qsum = 0; + { + __m256i qsum_acc = _mm256_setzero_si256(); + const __m256i ones = _mm256_set1_epi8(1); + for (int b = 0; b < num_blocks_32; b++) + { + __m256i q_256 = _mm256_loadu_si256((const __m256i*)(q + b * 32)); + qsum_acc = _mm256_dpbusd_epi32(qsum_acc, ones, q_256); + } + qsum = _mm256_reduce_add_epi32(qsum_acc); + for (int i = num_blocks_32 * 32; i < d; i++) + qsum += q[i]; + } + + int j = 0; + for (; j + 3 < block_n; j += 4) + { + const signed char* k0 = K + (n_start + j + 0) * d; + const signed char* k1 = K 
+ (n_start + j + 1) * d; + const signed char* k2 = K + (n_start + j + 2) * d; + const signed char* k3 = K + (n_start + j + 3) * d; + + __m256i acc0 = _mm256_setzero_si256(); + __m256i acc1 = _mm256_setzero_si256(); + __m256i acc2 = _mm256_setzero_si256(); + __m256i acc3 = _mm256_setzero_si256(); + const __m256i offset128 = _mm256_set1_epi8(128); + + for (int b = 0; b < num_blocks_32; b++) + { + int off = b * 32; + __m256i q_u8 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i*)(q + off)), offset128); + + acc0 = _mm256_dpbusd_epi32(acc0, q_u8, _mm256_loadu_si256((const __m256i*)(k0 + off))); + acc1 = _mm256_dpbusd_epi32(acc1, q_u8, _mm256_loadu_si256((const __m256i*)(k1 + off))); + acc2 = _mm256_dpbusd_epi32(acc2, q_u8, _mm256_loadu_si256((const __m256i*)(k2 + off))); + acc3 = _mm256_dpbusd_epi32(acc3, q_u8, _mm256_loadu_si256((const __m256i*)(k3 + off))); + } + + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + for (int i = num_blocks_32 * 32; i < d; i++) + { + scalar0 += q[i] * k0[i]; + scalar1 += q[i] * k1[i]; + scalar2 += q[i] * k2[i]; + scalar3 += q[i] * k3[i]; + } + + float sum0 = (float)(_mm256_reduce_add_epi32(acc0) - 128 * qsum + scalar0); + float sum1 = (float)(_mm256_reduce_add_epi32(acc1) - 128 * qsum + scalar1); + float sum2 = (float)(_mm256_reduce_add_epi32(acc2) - 128 * qsum + scalar2); + float sum3 = (float)(_mm256_reduce_add_epi32(acc3) - 128 * qsum + scalar3); + + s[j + 0] = sum0 * qscale * kscales[n_start + j + 0] * scale; + s[j + 1] = sum1 * qscale * kscales[n_start + j + 1] * scale; + s[j + 2] = sum2 * qscale * kscales[n_start + j + 2] * scale; + s[j + 3] = sum3 * qscale * kscales[n_start + j + 3] * scale; + } + + for (; j < block_n; j++) + { + const signed char* kptr = K + (n_start + j) * d; + __m256i acc = _mm256_setzero_si256(); + const __m256i offset128 = _mm256_set1_epi8(128); + for (int b = 0; b < num_blocks_32; b++) + { + int off = b * 32; + __m256i q_u8 = _mm256_add_epi8(_mm256_loadu_si256((const __m256i*)(q + off)), 
offset128); + acc = _mm256_dpbusd_epi32(acc, q_u8, _mm256_loadu_si256((const __m256i*)(kptr + off))); + } + int scalar = 0; + for (int i = num_blocks_32 * 32; i < d; i++) + scalar += q[i] * kptr[i]; + float sum = (float)(_mm256_reduce_add_epi32(acc) - 128 * qsum + scalar); + s[j] = sum * qscale * kscales[n_start + j] * scale; + } +} +#endif // __AVXVNNI__ + +#if __AVX512F__ +static inline void decode_qk_dot_int8_avx512_kernel(float* s, const signed char* q, + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) +{ + const float qscale = qscales[0]; + const int num_blocks_64 = d / 64; + +#if __AVX512VNNI__ + // Precompute qsum for dpbusd compensation: sum((q+128)*k) = sum(q*k) + 128*sum(k) + int qsum = 0; + { + __m512i qsum_acc = _mm512_setzero_si512(); + const __m512i ones = _mm512_set1_epi8(1); + for (int b = 0; b < num_blocks_64; b++) + { + __m512i q_512 = _mm512_loadu_si512((const __m512i*)(q + b * 64)); + qsum_acc = _mm512_dpbusd_epi32(qsum_acc, ones, q_512); + } + qsum = _mm512_reduce_add_epi32(qsum_acc); + for (int i = num_blocks_64 * 64; i < d; i++) + qsum += q[i]; + } +#endif + + int j = 0; + for (; j + 3 < block_n; j += 4) + { + const signed char* k0 = K + (n_start + j + 0) * d; + const signed char* k1 = K + (n_start + j + 1) * d; + const signed char* k2 = K + (n_start + j + 2) * d; + const signed char* k3 = K + (n_start + j + 3) * d; + +#if __AVX512VNNI__ + __m512i acc0 = _mm512_setzero_si512(); + __m512i acc1 = _mm512_setzero_si512(); + __m512i acc2 = _mm512_setzero_si512(); + __m512i acc3 = _mm512_setzero_si512(); + const __m512i offset128 = _mm512_set1_epi8(128); + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q_u8 = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(q + off)), offset128); + + acc0 = _mm512_dpbusd_epi32(acc0, q_u8, _mm512_loadu_si512((const __m512i*)(k0 + off))); + acc1 = _mm512_dpbusd_epi32(acc1, q_u8, _mm512_loadu_si512((const __m512i*)(k1 + 
off))); + acc2 = _mm512_dpbusd_epi32(acc2, q_u8, _mm512_loadu_si512((const __m512i*)(k2 + off))); + acc3 = _mm512_dpbusd_epi32(acc3, q_u8, _mm512_loadu_si512((const __m512i*)(k3 + off))); + } + + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + for (int i = num_blocks_64 * 64; i < d; i++) + { + scalar0 += q[i] * k0[i]; + scalar1 += q[i] * k1[i]; + scalar2 += q[i] * k2[i]; + scalar3 += q[i] * k3[i]; + } + + float sum0 = (float)(_mm512_reduce_add_epi32(acc0) - 128 * qsum + scalar0); + float sum1 = (float)(_mm512_reduce_add_epi32(acc1) - 128 * qsum + scalar1); + float sum2 = (float)(_mm512_reduce_add_epi32(acc2) - 128 * qsum + scalar2); + float sum3 = (float)(_mm512_reduce_add_epi32(acc3) - 128 * qsum + scalar3); +#else + __m512i acc0 = _mm512_setzero_epi32(); + __m512i acc1 = _mm512_setzero_epi32(); + __m512i acc2 = _mm512_setzero_epi32(); + __m512i acc3 = _mm512_setzero_epi32(); + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m256i q_32 = _mm256_loadu_si256((const __m256i*)(q + off)); + __m512i q_512 = _mm512_cvtepi8_epi16(q_32); + + acc0 = _mm512_comp_dpwssd_epi32(acc0, q_512, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k0 + off)))); + acc1 = _mm512_comp_dpwssd_epi32(acc1, q_512, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k1 + off)))); + acc2 = _mm512_comp_dpwssd_epi32(acc2, q_512, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k2 + off)))); + acc3 = _mm512_comp_dpwssd_epi32(acc3, q_512, _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(k3 + off)))); + } + for (int i = num_blocks_64 * 64; i < d; i++) + { + scalar0 += q[i] * k0[i]; + scalar1 += q[i] * k1[i]; + scalar2 += q[i] * k2[i]; + scalar3 += q[i] * k3[i]; + } + + float sum0 = (float)(_mm512_reduce_add_epi32(acc0) + scalar0); + float sum1 = (float)(_mm512_reduce_add_epi32(acc1) + scalar1); + float sum2 = (float)(_mm512_reduce_add_epi32(acc2) + scalar2); + float sum3 = 
(float)(_mm512_reduce_add_epi32(acc3) + scalar3); +#endif + s[j + 0] = sum0 * qscale * kscales[n_start + j + 0] * scale; + s[j + 1] = sum1 * qscale * kscales[n_start + j + 1] * scale; + s[j + 2] = sum2 * qscale * kscales[n_start + j + 2] * scale; + s[j + 3] = sum3 * qscale * kscales[n_start + j + 3] * scale; + } + + for (; j < block_n; j++) + { + const signed char* kptr = K + (n_start + j) * d; +#if __AVX512VNNI__ + __m512i acc = _mm512_setzero_si512(); + const __m512i offset128 = _mm512_set1_epi8(128); + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q_u8 = _mm512_add_epi8(_mm512_loadu_si512((const __m512i*)(q + off)), offset128); + acc = _mm512_dpbusd_epi32(acc, q_u8, _mm512_loadu_si512((const __m512i*)(kptr + off))); + } + int scalar = 0; + for (int i = num_blocks_64 * 64; i < d; i++) + scalar += q[i] * kptr[i]; + float sum = (float)(_mm512_reduce_add_epi32(acc) - 128 * qsum + scalar); +#else + __m512i acc = _mm512_setzero_epi32(); + int scalar = 0; + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + acc = _mm512_comp_dpwssd_epi32(acc, + _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(q + off))), + _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(kptr + off)))); + } + for (int i = num_blocks_64 * 64; i < d; i++) + scalar += q[i] * kptr[i]; + float sum = (float)(_mm512_reduce_add_epi32(acc) + scalar); +#endif + s[j] = sum * qscale * kscales[n_start + j] * scale; + } +} #endif // __AVX512F__ -static void decode_qk_dot_int8_dispatch(float* s, const signed char* q, +static void decode_qk_dot_int8(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) { #if __AVX512F__ - decode_qk_dot_int8_avx512(s, q, K, qscales, kscales, n_start, block_n, d, scale); +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + decode_qk_dot_int8_avx512vnni(s, q, K, qscales, kscales, 
n_start, block_n, d, scale); + return; + } +#endif + decode_qk_dot_int8_avx512_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); #else -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + decode_qk_dot_int8_avx512vnni(s, q, K, qscales, kscales, n_start, block_n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + decode_qk_dot_int8_avxvnniint8(s, q, K, qscales, kscales, n_start, block_n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + decode_qk_dot_int8_avxvnni(s, q, K, qscales, kscales, n_start, block_n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { decode_qk_dot_int8_avx2(s, q, K, qscales, kscales, n_start, block_n, d, scale); @@ -742,18 +1033,18 @@ static void decode_qk_dot_int8_dispatch(float* s, const signed char* q, } #endif #if __AVX2__ - decode_qk_dot_int8_avx2(s, q, K, qscales, kscales, n_start, block_n, d, scale); + decode_qk_dot_int8_avx2_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); #elif __SSE2__ - decode_qk_dot_int8_sse2(s, q, K, qscales, kscales, n_start, block_n, d, scale); + decode_qk_dot_int8_sse2_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); #else - decode_qk_dot_int8_scalar(s, q, K, qscales, kscales, n_start, block_n, d, scale); + decode_qk_dot_int8_scalar_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); #endif #endif } // ------------------- Prefill QK Int8 GEMM Row-wise ------------------- -static inline void qk_int8_gemm_row_scalar(float* s_row, +static inline void 
qk_int8_gemm_row_scalar_kernel(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) { @@ -767,14 +1058,14 @@ static inline void qk_int8_gemm_row_scalar(float* s_row, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum += qk_int8_dot_block_scalar(q_row + off, kptr + off, len); + sum += qk_int8_dot_block_scalar_kernel(q_row + off, kptr + off, len); } s_row[j] = (float)sum * qscale * kscales[j] * scale; } } #if __SSE2__ -static inline void qk_int8_gemm_row_sse2(float* s_row, +static inline void qk_int8_gemm_row_sse2_kernel(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) { @@ -791,8 +1082,8 @@ static inline void qk_int8_gemm_row_sse2(float* s_row, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum0 += qk_int8_dot_block_sse2(q_row + off, k0 + off, len); - sum1 += qk_int8_dot_block_sse2(q_row + off, k1 + off, len); + sum0 += qk_int8_dot_block_sse2_kernel(q_row + off, k0 + off, len); + sum1 += qk_int8_dot_block_sse2_kernel(q_row + off, k1 + off, len); } s_row[j] = (float)sum0 * qscale * kscales[j] * scale; s_row[j + 1] = (float)sum1 * qscale * kscales[j + 1] * scale; @@ -806,7 +1097,7 @@ static inline void qk_int8_gemm_row_sse2(float* s_row, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum += qk_int8_dot_block_scalar(q_row + off, kptr + off, len); + sum += qk_int8_dot_block_scalar_kernel(q_row + off, kptr + off, len); } s_row[j] = (float)sum * qscale * kscales[j] * scale; } @@ -814,7 +1105,7 @@ static inline void qk_int8_gemm_row_sse2(float* s_row, #endif // __SSE2__ #if __AVX2__ -inline void __attribute__((noinline)) qk_int8_gemm_row_avx2(float* s_row, +inline void qk_int8_gemm_row_avx2_kernel(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) { @@ -831,8 +1122,8 
@@ inline void __attribute__((noinline)) qk_int8_gemm_row_avx2(float* s_row, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum0 += qk_int8_dot_block_avx2(q_row + off, k0 + off, len); - sum1 += qk_int8_dot_block_avx2(q_row + off, k1 + off, len); + sum0 += qk_int8_dot_block_avx2_kernel(q_row + off, k0 + off, len); + sum1 += qk_int8_dot_block_avx2_kernel(q_row + off, k1 + off, len); } s_row[j] = (float)sum0 * qscale * kscales[j] * scale; s_row[j + 1] = (float)sum1 * qscale * kscales[j + 1] * scale; @@ -846,7 +1137,7 @@ inline void __attribute__((noinline)) qk_int8_gemm_row_avx2(float* s_row, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum += qk_int8_dot_block_avx2(q_row + off, kptr + off, len); + sum += qk_int8_dot_block_avx2_kernel(q_row + off, kptr + off, len); } s_row[j] = (float)sum * qscale * kscales[j] * scale; } @@ -854,7 +1145,7 @@ inline void __attribute__((noinline)) qk_int8_gemm_row_avx2(float* s_row, #endif // __AVX2__ #if __AVX512F__ -static inline void qk_int8_gemm_row_avx512(float* s_row, +static inline void qk_int8_gemm_row_avx512_kernel(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) { @@ -891,21 +1182,21 @@ static inline void qk_int8_gemm_row_avx512(float* s_row, __m512i k2_512 = _mm512_cvtepi8_epi16(k2_32); __m512i k3_512 = _mm512_cvtepi8_epi16(k3_32); - acc0 = _mm512_add_epi32(acc0, _mm512_madd_epi16(q_512, k0_512)); - acc1 = _mm512_add_epi32(acc1, _mm512_madd_epi16(q_512, k1_512)); - acc2 = _mm512_add_epi32(acc2, _mm512_madd_epi16(q_512, k2_512)); - acc3 = _mm512_add_epi32(acc3, _mm512_madd_epi16(q_512, k3_512)); + acc0 = _mm512_comp_dpwssd_epi32(acc0, q_512, k0_512); + acc1 = _mm512_comp_dpwssd_epi32(acc1, q_512, k1_512); + acc2 = _mm512_comp_dpwssd_epi32(acc2, q_512, k2_512); + acc3 = _mm512_comp_dpwssd_epi32(acc3, q_512, k3_512); } else { int bs; - bs = qk_int8_dot_block_avx512(q_row + off, k0 + off, len); 
+ bs = qk_int8_dot_block_avx512_kernel(q_row + off, k0 + off, len); scalar0 += bs; - bs = qk_int8_dot_block_avx512(q_row + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k1 + off, len); scalar1 += bs; - bs = qk_int8_dot_block_avx512(q_row + off, k2 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k2 + off, len); scalar2 += bs; - bs = qk_int8_dot_block_avx512(q_row + off, k3 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q_row + off, k3 + off, len); scalar3 += bs; } } @@ -938,11 +1229,11 @@ static inline void qk_int8_gemm_row_avx512(float* s_row, __m256i k_32 = _mm256_loadu_si256((const __m256i*)(kptr + off)); __m512i q_512 = _mm512_cvtepi8_epi16(q_32); __m512i k_512 = _mm512_cvtepi8_epi16(k_32); - acc = _mm512_add_epi32(acc, _mm512_madd_epi16(q_512, k_512)); + acc = _mm512_comp_dpwssd_epi32(acc, q_512, k_512); } else { - scalar_sum += qk_int8_dot_block_avx512(q_row + off, kptr + off, len); + scalar_sum += qk_int8_dot_block_avx512_kernel(q_row + off, kptr + off, len); } } float sum = (float)_mm512_reduce_add_epi32(acc) + (float)scalar_sum; @@ -951,14 +1242,150 @@ static inline void qk_int8_gemm_row_avx512(float* s_row, } #endif // __AVX512F__ -static void qk_int8_gemm_row_dispatch(float* s_row, +#if __AVX512VNNI__ +static inline void qk_int8_gemm_row_avx512vnni_kernel(float* s_row, + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) +{ + const int num_blocks_64 = d / 64; + int qsum_64byte = 0; + if (num_blocks_64 > 0) + { + __m512i ones = _mm512_set1_epi8(1); + __m512i sum_acc = _mm512_setzero_epi32(); + for (int b = 0; b < num_blocks_64; b++) + { + __m512i q_512 = _mm512_loadu_si512((const __m512i*)(q_row + b * 64)); + sum_acc = _mm512_dpbusd_epi32(sum_acc, ones, q_512); + } + qsum_64byte = _mm512_reduce_add_epi32(sum_acc); + } + + int j = 0; + for (; j + 3 < n; j += 4) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + 
const signed char* k2 = K + (j + 2) * d; + const signed char* k3 = K + (j + 3) * d; + + __m512i acc0 = _mm512_setzero_epi32(); + __m512i acc1 = _mm512_setzero_epi32(); + __m512i acc2 = _mm512_setzero_epi32(); + __m512i acc3 = _mm512_setzero_epi32(); + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q_512 = _mm512_loadu_si512((const __m512i*)(q_row + off)); + __m512i k0_512 = _mm512_loadu_si512((const __m512i*)(k0 + off)); + __m512i k1_512 = _mm512_loadu_si512((const __m512i*)(k1 + off)); + __m512i k2_512 = _mm512_loadu_si512((const __m512i*)(k2 + off)); + __m512i k3_512 = _mm512_loadu_si512((const __m512i*)(k3 + off)); + __m512i k0_u8 = _mm512_add_epi8(k0_512, _mm512_set1_epi8(128)); + __m512i k1_u8 = _mm512_add_epi8(k1_512, _mm512_set1_epi8(128)); + __m512i k2_u8 = _mm512_add_epi8(k2_512, _mm512_set1_epi8(128)); + __m512i k3_u8 = _mm512_add_epi8(k3_512, _mm512_set1_epi8(128)); + + acc0 = _mm512_dpbusd_epi32(acc0, k0_u8, q_512); + acc1 = _mm512_dpbusd_epi32(acc1, k1_u8, q_512); + acc2 = _mm512_dpbusd_epi32(acc2, k2_u8, q_512); + acc3 = _mm512_dpbusd_epi32(acc3, k3_u8, q_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar0 += q_row[k] * k0[k]; + scalar1 += q_row[k] * k1[k]; + scalar2 += q_row[k] * k2[k]; + scalar3 += q_row[k] * k3[k]; + } + } + + float descale0 = qscale * kscales[j]; + float descale1 = qscale * kscales[j + 1]; + float descale2 = qscale * kscales[j + 2]; + float descale3 = qscale * kscales[j + 3]; + float sum0 = (float)(_mm512_reduce_add_epi32(acc0) - 128 * qsum_64byte) + (float)scalar0; + float sum1 = (float)(_mm512_reduce_add_epi32(acc1) - 128 * qsum_64byte) + (float)scalar1; + float sum2 = (float)(_mm512_reduce_add_epi32(acc2) - 128 * qsum_64byte) + (float)scalar2; + float sum3 = (float)(_mm512_reduce_add_epi32(acc3) - 128 * qsum_64byte) + (float)scalar3; + s_row[j] = sum0 * descale0 * 
scale; + s_row[j + 1] = sum1 * descale1 * scale; + s_row[j + 2] = sum2 * descale2 * scale; + s_row[j + 3] = sum3 * descale3 * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + __m512i acc = _mm512_setzero_epi32(); + int scalar_sum = 0; + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q_512 = _mm512_loadu_si512((const __m512i*)(q_row + off)); + __m512i k_512 = _mm512_loadu_si512((const __m512i*)(kptr + off)); + __m512i k_u8 = _mm512_add_epi8(k_512, _mm512_set1_epi8(128)); + acc = _mm512_dpbusd_epi32(acc, k_u8, q_512); + } + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar_sum += q_row[k] * kptr[k]; + } + } + float sum = (float)(_mm512_reduce_add_epi32(acc) - 128 * qsum_64byte) + (float)scalar_sum; + s_row[j] = sum * qscale * kscales[j] * scale; + } +} +#endif // __AVX512VNNI__ + +static void qk_int8_gemm_row(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) { #if __AVX512F__ - qk_int8_gemm_row_avx512(s_row, q_row, K, qscale, kscales, n, d, scale); +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + qk_int8_gemm_row_avx512vnni(s_row, q_row, K, qscale, kscales, n, d, scale); + return; + } +#endif +#if __AVX512VNNI__ + qk_int8_gemm_row_avx512vnni_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); #else -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + qk_int8_gemm_row_avx512_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); +#endif +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + qk_int8_gemm_row_avx512vnni(s_row, q_row, K, qscale, kscales, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + 
qk_int8_gemm_row_avxvnniint8(s_row, q_row, K, qscale, kscales, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + qk_int8_gemm_row_avxvnni(s_row, q_row, K, qscale, kscales, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { qk_int8_gemm_row_avx2(s_row, q_row, K, qscale, kscales, n, d, scale); @@ -966,18 +1393,18 @@ static void qk_int8_gemm_row_dispatch(float* s_row, } #endif #if __AVX2__ - qk_int8_gemm_row_avx2(s_row, q_row, K, qscale, kscales, n, d, scale); + qk_int8_gemm_row_avx2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); #elif __SSE2__ - qk_int8_gemm_row_sse2(s_row, q_row, K, qscale, kscales, n, d, scale); + qk_int8_gemm_row_sse2_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); #else - qk_int8_gemm_row_scalar(s_row, q_row, K, qscale, kscales, n, d, scale); + qk_int8_gemm_row_scalar_kernel(s_row, q_row, K, qscale, kscales, n, d, scale); #endif #endif } // ------------------- Tiled QK Int8 GEMM (M-tiling) ------------------- -static inline void qk_int8_gemm_tiled_scalar(float* S, +static inline void qk_int8_gemm_tiled_scalar_kernel(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) @@ -999,8 +1426,8 @@ static inline void qk_int8_gemm_tiled_scalar(float* S, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum0 += qk_int8_dot_block_scalar(q0 + off, kptr + off, len); - sum1 += qk_int8_dot_block_scalar(q1 + off, kptr + off, len); + sum0 += qk_int8_dot_block_scalar_kernel(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_scalar_kernel(q1 + off, kptr + off, len); } S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; S[(i + 1) * n + j] = (float)sum1 * qs1 * 
kscales[j] * scale; @@ -1008,12 +1435,12 @@ static inline void qk_int8_gemm_tiled_scalar(float* S, } for (; i < m; i++) { - qk_int8_gemm_row_scalar(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + qk_int8_gemm_row_scalar_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); } } #if __SSE2__ -static inline void qk_int8_gemm_tiled_sse2(float* S, +static inline void qk_int8_gemm_tiled_sse2_kernel(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) @@ -1038,13 +1465,13 @@ static inline void qk_int8_gemm_tiled_sse2(float* S, int len = std::min(32, d - off); if (len <= 0) continue; int bs; - bs = qk_int8_dot_block_sse2(q0 + off, k0 + off, len); + bs = qk_int8_dot_block_sse2_kernel(q0 + off, k0 + off, len); sum00 += bs; - bs = qk_int8_dot_block_sse2(q0 + off, k1 + off, len); + bs = qk_int8_dot_block_sse2_kernel(q0 + off, k1 + off, len); sum01 += bs; - bs = qk_int8_dot_block_sse2(q1 + off, k0 + off, len); + bs = qk_int8_dot_block_sse2_kernel(q1 + off, k0 + off, len); sum10 += bs; - bs = qk_int8_dot_block_sse2(q1 + off, k1 + off, len); + bs = qk_int8_dot_block_sse2_kernel(q1 + off, k1 + off, len); sum11 += bs; } S[(i + 0) * n + j] = (float)sum00 * qs0 * kscales[j] * scale; @@ -1061,8 +1488,8 @@ static inline void qk_int8_gemm_tiled_sse2(float* S, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum0 += qk_int8_dot_block_sse2(q0 + off, kptr + off, len); - sum1 += qk_int8_dot_block_sse2(q1 + off, kptr + off, len); + sum0 += qk_int8_dot_block_sse2_kernel(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_sse2_kernel(q1 + off, kptr + off, len); } S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; @@ -1070,7 +1497,7 @@ static inline void qk_int8_gemm_tiled_sse2(float* S, } for (; i < m; i++) { - qk_int8_gemm_row_sse2(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, 
scale); + qk_int8_gemm_row_sse2_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); } } #endif // __SSE2__ @@ -1100,7 +1527,7 @@ static inline int _mm512_reduce_add_epi32(__m512i v) } #endif // __AVX512F__ -inline void __attribute__((noinline)) qk_int8_gemm_tiled_avx2(float* S, +inline void qk_int8_gemm_tiled_avx2_kernel(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) @@ -1149,53 +1576,53 @@ inline void __attribute__((noinline)) qk_int8_gemm_tiled_avx2(float* S, __m256i k0_32 = _mm256_loadu_si256((const __m256i*)(k0 + off)); __m256i k0_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k0_32)); __m256i k0_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k0_32, 1)); - acc00 = _mm256_add_epi32(acc00, _mm256_madd_epi16(q0_lo, k0_lo)); - acc00 = _mm256_add_epi32(acc00, _mm256_madd_epi16(q0_hi, k0_hi)); - acc10 = _mm256_add_epi32(acc10, _mm256_madd_epi16(q1_lo, k0_lo)); - acc10 = _mm256_add_epi32(acc10, _mm256_madd_epi16(q1_hi, k0_hi)); + acc00 = _mm256_comp_dpwssd_epi32(acc00, q0_lo, k0_lo); + acc00 = _mm256_comp_dpwssd_epi32(acc00, q0_hi, k0_hi); + acc10 = _mm256_comp_dpwssd_epi32(acc10, q1_lo, k0_lo); + acc10 = _mm256_comp_dpwssd_epi32(acc10, q1_hi, k0_hi); __m256i k1_32 = _mm256_loadu_si256((const __m256i*)(k1 + off)); __m256i k1_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k1_32)); __m256i k1_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k1_32, 1)); - acc01 = _mm256_add_epi32(acc01, _mm256_madd_epi16(q0_lo, k1_lo)); - acc01 = _mm256_add_epi32(acc01, _mm256_madd_epi16(q0_hi, k1_hi)); - acc11 = _mm256_add_epi32(acc11, _mm256_madd_epi16(q1_lo, k1_lo)); - acc11 = _mm256_add_epi32(acc11, _mm256_madd_epi16(q1_hi, k1_hi)); + acc01 = _mm256_comp_dpwssd_epi32(acc01, q0_lo, k1_lo); + acc01 = _mm256_comp_dpwssd_epi32(acc01, q0_hi, k1_hi); + acc11 = _mm256_comp_dpwssd_epi32(acc11, q1_lo, k1_lo); + acc11 = _mm256_comp_dpwssd_epi32(acc11, q1_hi, k1_hi); __m256i k2_32 = 
_mm256_loadu_si256((const __m256i*)(k2 + off)); __m256i k2_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k2_32)); __m256i k2_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k2_32, 1)); - acc02 = _mm256_add_epi32(acc02, _mm256_madd_epi16(q0_lo, k2_lo)); - acc02 = _mm256_add_epi32(acc02, _mm256_madd_epi16(q0_hi, k2_hi)); - acc12 = _mm256_add_epi32(acc12, _mm256_madd_epi16(q1_lo, k2_lo)); - acc12 = _mm256_add_epi32(acc12, _mm256_madd_epi16(q1_hi, k2_hi)); + acc02 = _mm256_comp_dpwssd_epi32(acc02, q0_lo, k2_lo); + acc02 = _mm256_comp_dpwssd_epi32(acc02, q0_hi, k2_hi); + acc12 = _mm256_comp_dpwssd_epi32(acc12, q1_lo, k2_lo); + acc12 = _mm256_comp_dpwssd_epi32(acc12, q1_hi, k2_hi); __m256i k3_32 = _mm256_loadu_si256((const __m256i*)(k3 + off)); __m256i k3_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(k3_32)); __m256i k3_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(k3_32, 1)); - acc03 = _mm256_add_epi32(acc03, _mm256_madd_epi16(q0_lo, k3_lo)); - acc03 = _mm256_add_epi32(acc03, _mm256_madd_epi16(q0_hi, k3_hi)); - acc13 = _mm256_add_epi32(acc13, _mm256_madd_epi16(q1_lo, k3_lo)); - acc13 = _mm256_add_epi32(acc13, _mm256_madd_epi16(q1_hi, k3_hi)); + acc03 = _mm256_comp_dpwssd_epi32(acc03, q0_lo, k3_lo); + acc03 = _mm256_comp_dpwssd_epi32(acc03, q0_hi, k3_hi); + acc13 = _mm256_comp_dpwssd_epi32(acc13, q1_lo, k3_lo); + acc13 = _mm256_comp_dpwssd_epi32(acc13, q1_hi, k3_hi); } else { int bs; - bs = qk_int8_dot_block_avx2(q0 + off, k0 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k0 + off, len); scalar00 += bs; - bs = qk_int8_dot_block_avx2(q0 + off, k1 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k1 + off, len); scalar01 += bs; - bs = qk_int8_dot_block_avx2(q0 + off, k2 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k2 + off, len); scalar02 += bs; - bs = qk_int8_dot_block_avx2(q0 + off, k3 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k3 + off, len); scalar03 += bs; - bs = qk_int8_dot_block_avx2(q1 + off, k0 + 
off, len); + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k0 + off, len); scalar10 += bs; - bs = qk_int8_dot_block_avx2(q1 + off, k1 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k1 + off, len); scalar11 += bs; - bs = qk_int8_dot_block_avx2(q1 + off, k2 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k2 + off, len); scalar12 += bs; - bs = qk_int8_dot_block_avx2(q1 + off, k3 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k3 + off, len); scalar13 += bs; } } @@ -1227,13 +1654,13 @@ inline void __attribute__((noinline)) qk_int8_gemm_tiled_avx2(float* S, int len = std::min(32, d - off); if (len <= 0) continue; int bs; - bs = qk_int8_dot_block_avx2(q0 + off, k0 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k0 + off, len); sum00 += bs; - bs = qk_int8_dot_block_avx2(q0 + off, k1 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q0 + off, k1 + off, len); sum01 += bs; - bs = qk_int8_dot_block_avx2(q1 + off, k0 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k0 + off, len); sum10 += bs; - bs = qk_int8_dot_block_avx2(q1 + off, k1 + off, len); + bs = qk_int8_dot_block_avx2_kernel(q1 + off, k1 + off, len); sum11 += bs; } S[(i + 0) * n + j] = (float)sum00 * qs0 * kscales[j] * scale; @@ -1250,8 +1677,8 @@ inline void __attribute__((noinline)) qk_int8_gemm_tiled_avx2(float* S, int off = b * 32; int len = std::min(32, d - off); if (len <= 0) continue; - sum0 += qk_int8_dot_block_avx2(q0 + off, kptr + off, len); - sum1 += qk_int8_dot_block_avx2(q1 + off, kptr + off, len); + sum0 += qk_int8_dot_block_avx2_kernel(q0 + off, kptr + off, len); + sum1 += qk_int8_dot_block_avx2_kernel(q1 + off, kptr + off, len); } S[(i + 0) * n + j] = (float)sum0 * qs0 * kscales[j] * scale; S[(i + 1) * n + j] = (float)sum1 * qs1 * kscales[j] * scale; @@ -1259,13 +1686,13 @@ inline void __attribute__((noinline)) qk_int8_gemm_tiled_avx2(float* S, } for (; i < m; i++) { - qk_int8_gemm_row_avx2(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, 
scale); + qk_int8_gemm_row_avx2_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); } } #endif // __AVX2__ #if __AVX512F__ -static inline void qk_int8_gemm_tiled_avx512(float* S, +static inline void qk_int8_gemm_tiled_avx512_kernel(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) @@ -1336,57 +1763,57 @@ static inline void qk_int8_gemm_tiled_avx512(float* S, __m512i k2_512 = _mm512_cvtepi8_epi16(k2_32); __m512i k3_512 = _mm512_cvtepi8_epi16(k3_32); - acc00 = _mm512_add_epi32(acc00, _mm512_madd_epi16(q0_512, k0_512)); - acc01 = _mm512_add_epi32(acc01, _mm512_madd_epi16(q0_512, k1_512)); - acc02 = _mm512_add_epi32(acc02, _mm512_madd_epi16(q0_512, k2_512)); - acc03 = _mm512_add_epi32(acc03, _mm512_madd_epi16(q0_512, k3_512)); - acc10 = _mm512_add_epi32(acc10, _mm512_madd_epi16(q1_512, k0_512)); - acc11 = _mm512_add_epi32(acc11, _mm512_madd_epi16(q1_512, k1_512)); - acc12 = _mm512_add_epi32(acc12, _mm512_madd_epi16(q1_512, k2_512)); - acc13 = _mm512_add_epi32(acc13, _mm512_madd_epi16(q1_512, k3_512)); - acc20 = _mm512_add_epi32(acc20, _mm512_madd_epi16(q2_512, k0_512)); - acc21 = _mm512_add_epi32(acc21, _mm512_madd_epi16(q2_512, k1_512)); - acc22 = _mm512_add_epi32(acc22, _mm512_madd_epi16(q2_512, k2_512)); - acc23 = _mm512_add_epi32(acc23, _mm512_madd_epi16(q2_512, k3_512)); - acc30 = _mm512_add_epi32(acc30, _mm512_madd_epi16(q3_512, k0_512)); - acc31 = _mm512_add_epi32(acc31, _mm512_madd_epi16(q3_512, k1_512)); - acc32 = _mm512_add_epi32(acc32, _mm512_madd_epi16(q3_512, k2_512)); - acc33 = _mm512_add_epi32(acc33, _mm512_madd_epi16(q3_512, k3_512)); + acc00 = _mm512_comp_dpwssd_epi32(acc00, q0_512, k0_512); + acc01 = _mm512_comp_dpwssd_epi32(acc01, q0_512, k1_512); + acc02 = _mm512_comp_dpwssd_epi32(acc02, q0_512, k2_512); + acc03 = _mm512_comp_dpwssd_epi32(acc03, q0_512, k3_512); + acc10 = _mm512_comp_dpwssd_epi32(acc10, q1_512, k0_512); + acc11 = 
_mm512_comp_dpwssd_epi32(acc11, q1_512, k1_512); + acc12 = _mm512_comp_dpwssd_epi32(acc12, q1_512, k2_512); + acc13 = _mm512_comp_dpwssd_epi32(acc13, q1_512, k3_512); + acc20 = _mm512_comp_dpwssd_epi32(acc20, q2_512, k0_512); + acc21 = _mm512_comp_dpwssd_epi32(acc21, q2_512, k1_512); + acc22 = _mm512_comp_dpwssd_epi32(acc22, q2_512, k2_512); + acc23 = _mm512_comp_dpwssd_epi32(acc23, q2_512, k3_512); + acc30 = _mm512_comp_dpwssd_epi32(acc30, q3_512, k0_512); + acc31 = _mm512_comp_dpwssd_epi32(acc31, q3_512, k1_512); + acc32 = _mm512_comp_dpwssd_epi32(acc32, q3_512, k2_512); + acc33 = _mm512_comp_dpwssd_epi32(acc33, q3_512, k3_512); } else { int bs; - bs = qk_int8_dot_block_avx512(q0 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, k0 + off, len); scalar00 += bs; - bs = qk_int8_dot_block_avx512(q0 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, k1 + off, len); scalar01 += bs; - bs = qk_int8_dot_block_avx512(q0 + off, k2 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, k2 + off, len); scalar02 += bs; - bs = qk_int8_dot_block_avx512(q0 + off, k3 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, k3 + off, len); scalar03 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, k0 + off, len); scalar10 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, k1 + off, len); scalar11 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, k2 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, k2 + off, len); scalar12 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, k3 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, k3 + off, len); scalar13 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q2 + off, k0 + off, len); scalar20 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, k1 + off, len); + bs = 
qk_int8_dot_block_avx512_kernel(q2 + off, k1 + off, len); scalar21 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, k2 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q2 + off, k2 + off, len); scalar22 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, k3 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q2 + off, k3 + off, len); scalar23 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, k0 + off, len); scalar30 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, k1 + off, len); scalar31 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, k2 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, k2 + off, len); scalar32 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, k3 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, k3 + off, len); scalar33 += bs; } } @@ -1476,33 +1903,33 @@ static inline void qk_int8_gemm_tiled_avx512(float* S, __m512i k0_512 = _mm512_cvtepi8_epi16(k0_32); __m512i k1_512 = _mm512_cvtepi8_epi16(k1_32); - acc00 = _mm512_add_epi32(acc00, _mm512_madd_epi16(q0_512, k0_512)); - acc01 = _mm512_add_epi32(acc01, _mm512_madd_epi16(q0_512, k1_512)); - acc10 = _mm512_add_epi32(acc10, _mm512_madd_epi16(q1_512, k0_512)); - acc11 = _mm512_add_epi32(acc11, _mm512_madd_epi16(q1_512, k1_512)); - acc20 = _mm512_add_epi32(acc20, _mm512_madd_epi16(q2_512, k0_512)); - acc21 = _mm512_add_epi32(acc21, _mm512_madd_epi16(q2_512, k1_512)); - acc30 = _mm512_add_epi32(acc30, _mm512_madd_epi16(q3_512, k0_512)); - acc31 = _mm512_add_epi32(acc31, _mm512_madd_epi16(q3_512, k1_512)); + acc00 = _mm512_comp_dpwssd_epi32(acc00, q0_512, k0_512); + acc01 = _mm512_comp_dpwssd_epi32(acc01, q0_512, k1_512); + acc10 = _mm512_comp_dpwssd_epi32(acc10, q1_512, k0_512); + acc11 = _mm512_comp_dpwssd_epi32(acc11, q1_512, k1_512); + acc20 = _mm512_comp_dpwssd_epi32(acc20, q2_512, k0_512); + acc21 = _mm512_comp_dpwssd_epi32(acc21, q2_512, k1_512); + 
acc30 = _mm512_comp_dpwssd_epi32(acc30, q3_512, k0_512); + acc31 = _mm512_comp_dpwssd_epi32(acc31, q3_512, k1_512); } else { int bs; - bs = qk_int8_dot_block_avx512(q0 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, k0 + off, len); scalar00 += bs; - bs = qk_int8_dot_block_avx512(q0 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, k1 + off, len); scalar01 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, k0 + off, len); scalar10 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, k1 + off, len); scalar11 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q2 + off, k0 + off, len); scalar20 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q2 + off, k1 + off, len); scalar21 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, k0 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, k0 + off, len); scalar30 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, k1 + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, k1 + off, len); scalar31 += bs; } } @@ -1560,21 +1987,21 @@ static inline void qk_int8_gemm_tiled_avx512(float* S, __m256i k_32 = _mm256_loadu_si256((const __m256i*)(kptr + off)); __m512i k_512 = _mm512_cvtepi8_epi16(k_32); - acc0 = _mm512_add_epi32(acc0, _mm512_madd_epi16(q0_512, k_512)); - acc1 = _mm512_add_epi32(acc1, _mm512_madd_epi16(q1_512, k_512)); - acc2 = _mm512_add_epi32(acc2, _mm512_madd_epi16(q2_512, k_512)); - acc3 = _mm512_add_epi32(acc3, _mm512_madd_epi16(q3_512, k_512)); + acc0 = _mm512_comp_dpwssd_epi32(acc0, q0_512, k_512); + acc1 = _mm512_comp_dpwssd_epi32(acc1, q1_512, k_512); + acc2 = _mm512_comp_dpwssd_epi32(acc2, q2_512, k_512); + acc3 = _mm512_comp_dpwssd_epi32(acc3, q3_512, k_512); } else { int bs; - bs = qk_int8_dot_block_avx512(q0 + off, kptr + off, 
len); + bs = qk_int8_dot_block_avx512_kernel(q0 + off, kptr + off, len); scalar0 += bs; - bs = qk_int8_dot_block_avx512(q1 + off, kptr + off, len); + bs = qk_int8_dot_block_avx512_kernel(q1 + off, kptr + off, len); scalar1 += bs; - bs = qk_int8_dot_block_avx512(q2 + off, kptr + off, len); + bs = qk_int8_dot_block_avx512_kernel(q2 + off, kptr + off, len); scalar2 += bs; - bs = qk_int8_dot_block_avx512(q3 + off, kptr + off, len); + bs = qk_int8_dot_block_avx512_kernel(q3 + off, kptr + off, len); scalar3 += bs; } } @@ -1591,20 +2018,367 @@ static inline void qk_int8_gemm_tiled_avx512(float* S, } for (; i < m; i++) { - qk_int8_gemm_row_avx512(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + qk_int8_gemm_row_avx512_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); } } #endif // __AVX512F__ -static void qk_int8_gemm_tiled_dispatch(float* S, +#if __AVX512VNNI__ +static inline void qk_int8_gemm_tiled_avx512vnni_kernel(float* S, + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) +{ + const int num_blocks_64 = d / 64; + int i = 0; + for (; i + 4 <= m; i += 4) + { + const signed char* q0 = Q + (i + 0) * d; + const signed char* q1 = Q + (i + 1) * d; + const signed char* q2 = Q + (i + 2) * d; + const signed char* q3 = Q + (i + 3) * d; + float qs0 = qscales[i + 0]; + float qs1 = qscales[i + 1]; + float qs2 = qscales[i + 2]; + float qs3 = qscales[i + 3]; + + int qsum0_64byte = 0, qsum1_64byte = 0, qsum2_64byte = 0, qsum3_64byte = 0; + if (num_blocks_64 > 0) + { + __m512i ones = _mm512_set1_epi8(1); + __m512i sum0 = _mm512_setzero_epi32(); + __m512i sum1 = _mm512_setzero_epi32(); + __m512i sum2 = _mm512_setzero_epi32(); + __m512i sum3 = _mm512_setzero_epi32(); + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + sum0 = _mm512_dpbusd_epi32(sum0, ones, _mm512_loadu_si512((const __m512i*)(q0 + off))); + sum1 = _mm512_dpbusd_epi32(sum1, ones, 
_mm512_loadu_si512((const __m512i*)(q1 + off))); + sum2 = _mm512_dpbusd_epi32(sum2, ones, _mm512_loadu_si512((const __m512i*)(q2 + off))); + sum3 = _mm512_dpbusd_epi32(sum3, ones, _mm512_loadu_si512((const __m512i*)(q3 + off))); + } + qsum0_64byte = _mm512_reduce_add_epi32(sum0); + qsum1_64byte = _mm512_reduce_add_epi32(sum1); + qsum2_64byte = _mm512_reduce_add_epi32(sum2); + qsum3_64byte = _mm512_reduce_add_epi32(sum3); + } + + int j = 0; + for (; j + 3 < n; j += 4) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + const signed char* k2 = K + (j + 2) * d; + const signed char* k3 = K + (j + 3) * d; + + __m512i acc00 = _mm512_setzero_epi32(); + __m512i acc01 = _mm512_setzero_epi32(); + __m512i acc02 = _mm512_setzero_epi32(); + __m512i acc03 = _mm512_setzero_epi32(); + __m512i acc10 = _mm512_setzero_epi32(); + __m512i acc11 = _mm512_setzero_epi32(); + __m512i acc12 = _mm512_setzero_epi32(); + __m512i acc13 = _mm512_setzero_epi32(); + __m512i acc20 = _mm512_setzero_epi32(); + __m512i acc21 = _mm512_setzero_epi32(); + __m512i acc22 = _mm512_setzero_epi32(); + __m512i acc23 = _mm512_setzero_epi32(); + __m512i acc30 = _mm512_setzero_epi32(); + __m512i acc31 = _mm512_setzero_epi32(); + __m512i acc32 = _mm512_setzero_epi32(); + __m512i acc33 = _mm512_setzero_epi32(); + int scalar00 = 0, scalar01 = 0, scalar02 = 0, scalar03 = 0; + int scalar10 = 0, scalar11 = 0, scalar12 = 0, scalar13 = 0; + int scalar20 = 0, scalar21 = 0, scalar22 = 0, scalar23 = 0; + int scalar30 = 0, scalar31 = 0, scalar32 = 0, scalar33 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q0_512 = _mm512_loadu_si512((const __m512i*)(q0 + off)); + __m512i q1_512 = _mm512_loadu_si512((const __m512i*)(q1 + off)); + __m512i q2_512 = _mm512_loadu_si512((const __m512i*)(q2 + off)); + __m512i q3_512 = _mm512_loadu_si512((const __m512i*)(q3 + off)); + + __m512i k0_512 = _mm512_loadu_si512((const __m512i*)(k0 + off)); + __m512i k1_512 = 
_mm512_loadu_si512((const __m512i*)(k1 + off)); + __m512i k2_512 = _mm512_loadu_si512((const __m512i*)(k2 + off)); + __m512i k3_512 = _mm512_loadu_si512((const __m512i*)(k3 + off)); + __m512i k0_u8 = _mm512_add_epi8(k0_512, _mm512_set1_epi8(128)); + __m512i k1_u8 = _mm512_add_epi8(k1_512, _mm512_set1_epi8(128)); + __m512i k2_u8 = _mm512_add_epi8(k2_512, _mm512_set1_epi8(128)); + __m512i k3_u8 = _mm512_add_epi8(k3_512, _mm512_set1_epi8(128)); + + acc00 = _mm512_dpbusd_epi32(acc00, k0_u8, q0_512); + acc01 = _mm512_dpbusd_epi32(acc01, k1_u8, q0_512); + acc02 = _mm512_dpbusd_epi32(acc02, k2_u8, q0_512); + acc03 = _mm512_dpbusd_epi32(acc03, k3_u8, q0_512); + acc10 = _mm512_dpbusd_epi32(acc10, k0_u8, q1_512); + acc11 = _mm512_dpbusd_epi32(acc11, k1_u8, q1_512); + acc12 = _mm512_dpbusd_epi32(acc12, k2_u8, q1_512); + acc13 = _mm512_dpbusd_epi32(acc13, k3_u8, q1_512); + acc20 = _mm512_dpbusd_epi32(acc20, k0_u8, q2_512); + acc21 = _mm512_dpbusd_epi32(acc21, k1_u8, q2_512); + acc22 = _mm512_dpbusd_epi32(acc22, k2_u8, q2_512); + acc23 = _mm512_dpbusd_epi32(acc23, k3_u8, q2_512); + acc30 = _mm512_dpbusd_epi32(acc30, k0_u8, q3_512); + acc31 = _mm512_dpbusd_epi32(acc31, k1_u8, q3_512); + acc32 = _mm512_dpbusd_epi32(acc32, k2_u8, q3_512); + acc33 = _mm512_dpbusd_epi32(acc33, k3_u8, q3_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar00 += q0[k] * k0[k]; + scalar01 += q0[k] * k1[k]; + scalar02 += q0[k] * k2[k]; + scalar03 += q0[k] * k3[k]; + scalar10 += q1[k] * k0[k]; + scalar11 += q1[k] * k1[k]; + scalar12 += q1[k] * k2[k]; + scalar13 += q1[k] * k3[k]; + scalar20 += q2[k] * k0[k]; + scalar21 += q2[k] * k1[k]; + scalar22 += q2[k] * k2[k]; + scalar23 += q2[k] * k3[k]; + scalar30 += q3[k] * k0[k]; + scalar31 += q3[k] * k1[k]; + scalar32 += q3[k] * k2[k]; + scalar33 += q3[k] * k3[k]; + } + } + + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale02 = qs0 * kscales[j 
+ 2]; + float descale03 = qs0 * kscales[j + 3]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale12 = qs1 * kscales[j + 2]; + float descale13 = qs1 * kscales[j + 3]; + float descale20 = qs2 * kscales[j]; + float descale21 = qs2 * kscales[j + 1]; + float descale22 = qs2 * kscales[j + 2]; + float descale23 = qs2 * kscales[j + 3]; + float descale30 = qs3 * kscales[j]; + float descale31 = qs3 * kscales[j + 1]; + float descale32 = qs3 * kscales[j + 2]; + float descale33 = qs3 * kscales[j + 3]; + float sum00 = (float)(_mm512_reduce_add_epi32(acc00) - 128 * qsum0_64byte) + (float)scalar00; + float sum01 = (float)(_mm512_reduce_add_epi32(acc01) - 128 * qsum0_64byte) + (float)scalar01; + float sum02 = (float)(_mm512_reduce_add_epi32(acc02) - 128 * qsum0_64byte) + (float)scalar02; + float sum03 = (float)(_mm512_reduce_add_epi32(acc03) - 128 * qsum0_64byte) + (float)scalar03; + float sum10 = (float)(_mm512_reduce_add_epi32(acc10) - 128 * qsum1_64byte) + (float)scalar10; + float sum11 = (float)(_mm512_reduce_add_epi32(acc11) - 128 * qsum1_64byte) + (float)scalar11; + float sum12 = (float)(_mm512_reduce_add_epi32(acc12) - 128 * qsum1_64byte) + (float)scalar12; + float sum13 = (float)(_mm512_reduce_add_epi32(acc13) - 128 * qsum1_64byte) + (float)scalar13; + float sum20 = (float)(_mm512_reduce_add_epi32(acc20) - 128 * qsum2_64byte) + (float)scalar20; + float sum21 = (float)(_mm512_reduce_add_epi32(acc21) - 128 * qsum2_64byte) + (float)scalar21; + float sum22 = (float)(_mm512_reduce_add_epi32(acc22) - 128 * qsum2_64byte) + (float)scalar22; + float sum23 = (float)(_mm512_reduce_add_epi32(acc23) - 128 * qsum2_64byte) + (float)scalar23; + float sum30 = (float)(_mm512_reduce_add_epi32(acc30) - 128 * qsum3_64byte) + (float)scalar30; + float sum31 = (float)(_mm512_reduce_add_epi32(acc31) - 128 * qsum3_64byte) + (float)scalar31; + float sum32 = (float)(_mm512_reduce_add_epi32(acc32) - 128 * qsum3_64byte) + (float)scalar32; + float sum33 = 
(float)(_mm512_reduce_add_epi32(acc33) - 128 * qsum3_64byte) + (float)scalar33; + S[(i + 0) * n + j + 0] = sum00 * descale00 * scale; + S[(i + 0) * n + j + 1] = sum01 * descale01 * scale; + S[(i + 0) * n + j + 2] = sum02 * descale02 * scale; + S[(i + 0) * n + j + 3] = sum03 * descale03 * scale; + S[(i + 1) * n + j + 0] = sum10 * descale10 * scale; + S[(i + 1) * n + j + 1] = sum11 * descale11 * scale; + S[(i + 1) * n + j + 2] = sum12 * descale12 * scale; + S[(i + 1) * n + j + 3] = sum13 * descale13 * scale; + S[(i + 2) * n + j + 0] = sum20 * descale20 * scale; + S[(i + 2) * n + j + 1] = sum21 * descale21 * scale; + S[(i + 2) * n + j + 2] = sum22 * descale22 * scale; + S[(i + 2) * n + j + 3] = sum23 * descale23 * scale; + S[(i + 3) * n + j + 0] = sum30 * descale30 * scale; + S[(i + 3) * n + j + 1] = sum31 * descale31 * scale; + S[(i + 3) * n + j + 2] = sum32 * descale32 * scale; + S[(i + 3) * n + j + 3] = sum33 * descale33 * scale; + } + for (; j + 1 < n; j += 2) + { + const signed char* k0 = K + j * d; + const signed char* k1 = K + (j + 1) * d; + + __m512i acc00 = _mm512_setzero_epi32(); + __m512i acc01 = _mm512_setzero_epi32(); + __m512i acc10 = _mm512_setzero_epi32(); + __m512i acc11 = _mm512_setzero_epi32(); + __m512i acc20 = _mm512_setzero_epi32(); + __m512i acc21 = _mm512_setzero_epi32(); + __m512i acc30 = _mm512_setzero_epi32(); + __m512i acc31 = _mm512_setzero_epi32(); + int scalar00 = 0, scalar01 = 0, scalar10 = 0, scalar11 = 0; + int scalar20 = 0, scalar21 = 0, scalar30 = 0, scalar31 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q0_512 = _mm512_loadu_si512((const __m512i*)(q0 + off)); + __m512i q1_512 = _mm512_loadu_si512((const __m512i*)(q1 + off)); + __m512i q2_512 = _mm512_loadu_si512((const __m512i*)(q2 + off)); + __m512i q3_512 = _mm512_loadu_si512((const __m512i*)(q3 + off)); + + __m512i k0_512 = _mm512_loadu_si512((const __m512i*)(k0 + off)); + __m512i k1_512 = _mm512_loadu_si512((const __m512i*)(k1 + off)); + 
__m512i k0_u8 = _mm512_add_epi8(k0_512, _mm512_set1_epi8(128)); + __m512i k1_u8 = _mm512_add_epi8(k1_512, _mm512_set1_epi8(128)); + + acc00 = _mm512_dpbusd_epi32(acc00, k0_u8, q0_512); + acc01 = _mm512_dpbusd_epi32(acc01, k1_u8, q0_512); + acc10 = _mm512_dpbusd_epi32(acc10, k0_u8, q1_512); + acc11 = _mm512_dpbusd_epi32(acc11, k1_u8, q1_512); + acc20 = _mm512_dpbusd_epi32(acc20, k0_u8, q2_512); + acc21 = _mm512_dpbusd_epi32(acc21, k1_u8, q2_512); + acc30 = _mm512_dpbusd_epi32(acc30, k0_u8, q3_512); + acc31 = _mm512_dpbusd_epi32(acc31, k1_u8, q3_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar00 += q0[k] * k0[k]; + scalar01 += q0[k] * k1[k]; + scalar10 += q1[k] * k0[k]; + scalar11 += q1[k] * k1[k]; + scalar20 += q2[k] * k0[k]; + scalar21 += q2[k] * k1[k]; + scalar30 += q3[k] * k0[k]; + scalar31 += q3[k] * k1[k]; + } + } + + float descale00 = qs0 * kscales[j]; + float descale01 = qs0 * kscales[j + 1]; + float descale10 = qs1 * kscales[j]; + float descale11 = qs1 * kscales[j + 1]; + float descale20 = qs2 * kscales[j]; + float descale21 = qs2 * kscales[j + 1]; + float descale30 = qs3 * kscales[j]; + float descale31 = qs3 * kscales[j + 1]; + float sum00 = (float)(_mm512_reduce_add_epi32(acc00) - 128 * qsum0_64byte) + (float)scalar00; + float sum01 = (float)(_mm512_reduce_add_epi32(acc01) - 128 * qsum0_64byte) + (float)scalar01; + float sum10 = (float)(_mm512_reduce_add_epi32(acc10) - 128 * qsum1_64byte) + (float)scalar10; + float sum11 = (float)(_mm512_reduce_add_epi32(acc11) - 128 * qsum1_64byte) + (float)scalar11; + float sum20 = (float)(_mm512_reduce_add_epi32(acc20) - 128 * qsum2_64byte) + (float)scalar20; + float sum21 = (float)(_mm512_reduce_add_epi32(acc21) - 128 * qsum2_64byte) + (float)scalar21; + float sum30 = (float)(_mm512_reduce_add_epi32(acc30) - 128 * qsum3_64byte) + (float)scalar30; + float sum31 = (float)(_mm512_reduce_add_epi32(acc31) - 128 * qsum3_64byte) + (float)scalar31; 
+ S[(i + 0) * n + j] = sum00 * descale00 * scale; + S[(i + 0) * n + j + 1] = sum01 * descale01 * scale; + S[(i + 1) * n + j] = sum10 * descale10 * scale; + S[(i + 1) * n + j + 1] = sum11 * descale11 * scale; + S[(i + 2) * n + j] = sum20 * descale20 * scale; + S[(i + 2) * n + j + 1] = sum21 * descale21 * scale; + S[(i + 3) * n + j] = sum30 * descale30 * scale; + S[(i + 3) * n + j + 1] = sum31 * descale31 * scale; + } + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + + __m512i acc0 = _mm512_setzero_epi32(); + __m512i acc1 = _mm512_setzero_epi32(); + __m512i acc2 = _mm512_setzero_epi32(); + __m512i acc3 = _mm512_setzero_epi32(); + int scalar0 = 0, scalar1 = 0, scalar2 = 0, scalar3 = 0; + + for (int b = 0; b < num_blocks_64; b++) + { + int off = b * 64; + __m512i q0_512 = _mm512_loadu_si512((const __m512i*)(q0 + off)); + __m512i q1_512 = _mm512_loadu_si512((const __m512i*)(q1 + off)); + __m512i q2_512 = _mm512_loadu_si512((const __m512i*)(q2 + off)); + __m512i q3_512 = _mm512_loadu_si512((const __m512i*)(q3 + off)); + + __m512i k_512 = _mm512_loadu_si512((const __m512i*)(kptr + off)); + __m512i k_u8 = _mm512_add_epi8(k_512, _mm512_set1_epi8(128)); + + acc0 = _mm512_dpbusd_epi32(acc0, k_u8, q0_512); + acc1 = _mm512_dpbusd_epi32(acc1, k_u8, q1_512); + acc2 = _mm512_dpbusd_epi32(acc2, k_u8, q2_512); + acc3 = _mm512_dpbusd_epi32(acc3, k_u8, q3_512); + } + + int tail_start = num_blocks_64 * 64; + if (tail_start < d) + { + for (int k = tail_start; k < d; k++) + { + scalar0 += q0[k] * kptr[k]; + scalar1 += q1[k] * kptr[k]; + scalar2 += q2[k] * kptr[k]; + scalar3 += q3[k] * kptr[k]; + } + } + + float descale = kscales[j]; + float sum0 = (float)(_mm512_reduce_add_epi32(acc0) - 128 * qsum0_64byte) + (float)scalar0; + float sum1 = (float)(_mm512_reduce_add_epi32(acc1) - 128 * qsum1_64byte) + (float)scalar1; + float sum2 = (float)(_mm512_reduce_add_epi32(acc2) - 128 * qsum2_64byte) + (float)scalar2; + float sum3 = (float)(_mm512_reduce_add_epi32(acc3) - 128 * 
qsum3_64byte) + (float)scalar3; + S[(i + 0) * n + j] = sum0 * qs0 * descale * scale; + S[(i + 1) * n + j] = sum1 * qs1 * descale * scale; + S[(i + 2) * n + j] = sum2 * qs2 * descale * scale; + S[(i + 3) * n + j] = sum3 * qs3 * descale * scale; + } + } + for (; i < m; i++) + { + qk_int8_gemm_row_avx512vnni_kernel(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} +#endif // __AVX512VNNI__ + +static void qk_int8_gemm_tiled(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) { #if __AVX512F__ - qk_int8_gemm_tiled_avx512(S, Q, K, qscales, kscales, m, n, d, scale); +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + qk_int8_gemm_tiled_avx512vnni(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if __AVX512VNNI__ + qk_int8_gemm_tiled_avx512vnni_kernel(S, Q, K, qscales, kscales, m, n, d, scale); #else -#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ + qk_int8_gemm_tiled_avx512_kernel(S, Q, K, qscales, kscales, m, n, d, scale); +#endif +#else +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + qk_int8_gemm_tiled_avx512vnni(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNIINT8 && __AVX__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni_int8()) + { + qk_int8_gemm_tiled_avxvnniint8(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + qk_int8_gemm_tiled_avxvnni(S, Q, K, qscales, kscales, m, n, d, scale); + return; + } +#endif +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVXVNNIINT8__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx2()) { 
qk_int8_gemm_tiled_avx2(S, Q, K, qscales, kscales, m, n, d, scale); @@ -1612,18 +2386,18 @@ static void qk_int8_gemm_tiled_dispatch(float* S, } #endif #if __AVX2__ - qk_int8_gemm_tiled_avx2(S, Q, K, qscales, kscales, m, n, d, scale); + qk_int8_gemm_tiled_avx2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); #elif __SSE2__ - qk_int8_gemm_tiled_sse2(S, Q, K, qscales, kscales, m, n, d, scale); + qk_int8_gemm_tiled_sse2_kernel(S, Q, K, qscales, kscales, m, n, d, scale); #else - qk_int8_gemm_tiled_scalar(S, Q, K, qscales, kscales, m, n, d, scale); + qk_int8_gemm_tiled_scalar_kernel(S, Q, K, qscales, kscales, m, n, d, scale); #endif #endif } // ------------------- Decode PV GEMV Int8 ------------------- -static inline void decode_pv_gemv_int8_scalar(float* out, const float* s, +static inline void decode_pv_gemv_int8_scalar_kernel(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) { @@ -1647,7 +2421,7 @@ static inline void decode_pv_gemv_int8_scalar(float* out, const float* s, } #if __SSE2__ -static inline void decode_pv_gemv_int8_sse2(float* out, const float* s, +static inline void decode_pv_gemv_int8_sse2_kernel(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) { @@ -1672,7 +2446,7 @@ static inline void decode_pv_gemv_int8_sse2(float* out, const float* s, __m128 vfp = _mm_cvtepi32_ps(v32); oval = _mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); } - _mm_storeu_ps(out + k, oval); + _mm_storeu_ps(out + k, _mm_add_ps(_mm_loadu_ps(out + k), oval)); } for (; k < k_end; k++) { @@ -1682,14 +2456,14 @@ static inline void decode_pv_gemv_int8_sse2(float* out, const float* s, float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; sum += p_invscale * V[(n_start + j) * out_d + k]; } - out[k] = sum; + out[k] += sum; } } } #endif // __SSE2__ #if __AVX2__ -inline void __attribute__((noinline)) decode_pv_gemv_int8_avx2(float* out, const float* s, +inline void 
decode_pv_gemv_int8_avx2_kernel(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) { @@ -1712,7 +2486,7 @@ inline void __attribute__((noinline)) decode_pv_gemv_int8_avx2(float* out, const __m256 vfp = _mm256_cvtepi32_ps(v32); oval = _mm256_fmadd_ps(pvec, vfp, oval); } - _mm256_storeu_ps(out + k, oval); + _mm256_storeu_ps(out + k, _mm256_add_ps(_mm256_loadu_ps(out + k), oval)); } for (; k + 3 < k_end; k += 4) { @@ -1740,62 +2514,114 @@ inline void __attribute__((noinline)) decode_pv_gemv_int8_avx2(float* out, const #endif // __AVX__ #if __AVX512F__ -static inline void decode_pv_gemv_int8_avx512(float* out, const float* s, +static inline void decode_pv_gemv_int8_avx512_kernel(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) { const int num_blocks = (out_d + 31) / 32; - for (int vb = 0; vb < num_blocks; vb++) + int j = 0; + for (; j + 1 < block_n; j += 2) { - int k_start = vb * 32; - int k_end = k_start + 32 < out_d ? 
k_start + 32 : out_d; + float p0 = s[j]; + float p1 = s[j + 1]; + const signed char* v0 = V + (n_start + j) * out_d; + const signed char* v1 = V + (n_start + j + 1) * out_d; - int k = k_start; - for (; k + 15 < k_end; k += 16) + int k = 0; + for (; k + 31 < out_d; k += 32) { - __m512 oval = _mm512_setzero_ps(); - for (int j = 0; j < block_n; j++) - { - float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; - __m512 pvec = _mm512_set1_ps(p_invscale); - __m128i v8 = _mm_loadu_si128((const __m128i*)(V + (n_start + j) * out_d + k)); - __m512i v32 = _mm512_cvtepi8_epi32(v8); - __m512 vfp = _mm512_cvtepi32_ps(v32); - oval = _mm512_fmadd_ps(pvec, vfp, oval); - } + int vb = k / 32; + float p0_invscale = p0 / vscales[(n_start + j) * num_blocks + vb]; + float p1_invscale = p1 / vscales[(n_start + j + 1) * num_blocks + vb]; + + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + + __m128i v8_0a = _mm_loadu_si128((const __m128i*)(v0 + k)); + __m128i v8_0b = _mm_loadu_si128((const __m128i*)(v0 + k + 16)); + oval0 = _mm512_fmadd_ps(_mm512_set1_ps(p0_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0a)), oval0); + oval1 = _mm512_fmadd_ps(_mm512_set1_ps(p0_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0b)), oval1); + + __m128i v8_1a = _mm_loadu_si128((const __m128i*)(v1 + k)); + __m128i v8_1b = _mm_loadu_si128((const __m128i*)(v1 + k + 16)); + oval0 = _mm512_fmadd_ps(_mm512_set1_ps(p1_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1a)), oval0); + oval1 = _mm512_fmadd_ps(_mm512_set1_ps(p1_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1b)), oval1); + + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + for (; k + 15 < out_d; k += 16) + { + int vb = k / 32; + float p0_invscale = p0 / vscales[(n_start + j) * num_blocks + vb]; + float p1_invscale = p1 / vscales[(n_start + j + 1) * num_blocks + vb]; + + __m512 oval = _mm512_loadu_ps(out + k); + + __m128i v8_0 = 
_mm_loadu_si128((const __m128i*)(v0 + k)); + oval = _mm512_fmadd_ps(_mm512_set1_ps(p0_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), oval); + + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(v1 + k)); + oval = _mm512_fmadd_ps(_mm512_set1_ps(p1_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), oval); + _mm512_storeu_ps(out + k, oval); } - for (; k + 7 < k_end; k += 8) + for (; k < out_d; k++) { - __m256 oval = _mm256_setzero_ps(); - for (int j = 0; j < block_n; j++) - { - float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; - __m256 pvec = _mm256_set1_ps(p_invscale); - __m128i v8 = _mm_loadl_epi64((const __m128i*)(V + (n_start + j) * out_d + k)); - __m256i v32 = _mm256_cvtepi8_epi32(v8); - __m256 vfp = _mm256_cvtepi32_ps(v32); - oval = _mm256_fmadd_ps(pvec, vfp, oval); - } - _mm256_storeu_ps(out + k, oval); + int vb = k / 32; + float p0_invscale = p0 / vscales[(n_start + j) * num_blocks + vb]; + float p1_invscale = p1 / vscales[(n_start + j + 1) * num_blocks + vb]; + out[k] += p0_invscale * v0[k] + p1_invscale * v1[k]; } - for (; k < k_end; k++) + } + for (; j < block_n; j++) + { + float p = s[j]; + const signed char* v = V + (n_start + j) * out_d; + + int k = 0; + for (; k + 31 < out_d; k += 32) { - float sum = 0.f; - for (int j = 0; j < block_n; j++) - sum += s[j] / vscales[(n_start + j) * num_blocks + vb] * V[(n_start + j) * out_d + k]; - out[k] = sum; + int vb = k / 32; + float p_invscale = p / vscales[(n_start + j) * num_blocks + vb]; + + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + + __m128i v8_0 = _mm_loadu_si128((const __m128i*)(v + k)); + __m128i v8_1 = _mm_loadu_si128((const __m128i*)(v + k + 16)); + oval0 = _mm512_fmadd_ps(_mm512_set1_ps(p_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_0)), oval0); + oval1 = _mm512_fmadd_ps(_mm512_set1_ps(p_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8_1)), oval1); + + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out 
+ k + 16, oval1); + } + for (; k + 15 < out_d; k += 16) + { + int vb = k / 32; + float p_invscale = p / vscales[(n_start + j) * num_blocks + vb]; + + __m512 oval = _mm512_loadu_ps(out + k); + __m128i v8 = _mm_loadu_si128((const __m128i*)(v + k)); + oval = _mm512_fmadd_ps(_mm512_set1_ps(p_invscale), _mm512_cvtepi32_ps(_mm512_cvtepi8_epi32(v8)), oval); + _mm512_storeu_ps(out + k, oval); + } + for (; k < out_d; k++) + { + int vb = k / 32; + float p_invscale = p / vscales[(n_start + j) * num_blocks + vb]; + out[k] += p_invscale * v[k]; } } } #endif // __AVX512F__ -static void decode_pv_gemv_int8_dispatch(float* out, const float* s, +static void decode_pv_gemv_int8(float* out, const float* s, const signed char* V, const float* vscales, int n_start, int block_n, int out_d) { #if __AVX512F__ - decode_pv_gemv_int8_avx512(out, s, V, vscales, n_start, block_n, out_d); + decode_pv_gemv_int8_avx512_kernel(out, s, V, vscales, n_start, block_n, out_d); #else #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ if (ncnn::cpu_support_x86_avx2()) @@ -1807,16 +2633,16 @@ static void decode_pv_gemv_int8_dispatch(float* out, const float* s, #if __AVX2__ decode_pv_gemv_int8_avx2(out, s, V, vscales, n_start, block_n, out_d); #elif __SSE2__ - decode_pv_gemv_int8_sse2(out, s, V, vscales, n_start, block_n, out_d); + decode_pv_gemv_int8_sse2_kernel(out, s, V, vscales, n_start, block_n, out_d); #else - decode_pv_gemv_int8_scalar(out, s, V, vscales, n_start, block_n, out_d); + decode_pv_gemv_int8_scalar_kernel(out, s, V, vscales, n_start, block_n, out_d); #endif #endif } // ------------------- Prefill PV Float×Int8 GEMM Row-wise ------------------- -static inline void pv_float_int8_gemm_row_scalar(float* out, const float* p_row, +static inline void pv_float_int8_gemm_row_scalar_kernel(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) { @@ -1840,7 +2666,7 @@ static inline void pv_float_int8_gemm_row_scalar(float* out, const float* p_row, } #if 
__SSE2__ -static inline void pv_float_int8_gemm_row_sse2(float* out, const float* p_row, +static inline void pv_float_int8_gemm_row_sse2_kernel(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) { @@ -1876,7 +2702,7 @@ static inline void pv_float_int8_gemm_row_sse2(float* out, const float* p_row, #endif // __SSE2__ #if __AVX2__ -inline void __attribute__((noinline)) pv_float_int8_gemm_row_avx2(float* out, const float* p_row, +inline void pv_float_int8_gemm_row_avx2_kernel(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) { @@ -1926,7 +2752,7 @@ inline void __attribute__((noinline)) pv_float_int8_gemm_row_avx2(float* out, co #endif // __AVX2__ #if __AVX512F__ -static inline void pv_float_int8_gemm_row_avx512(float* out, const float* p_row, +static inline void pv_float_int8_gemm_row_avx512_kernel(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) { @@ -1975,12 +2801,12 @@ static inline void pv_float_int8_gemm_row_avx512(float* out, const float* p_row, } #endif // __AVX512F__ -static void pv_float_int8_gemm_row_dispatch(float* out, const float* p_row, +static void pv_float_int8_gemm_row(float* out, const float* p_row, const signed char* V, const float* vscales, int n, int out_d) { #if __AVX512F__ - pv_float_int8_gemm_row_avx512(out, p_row, V, vscales, n, out_d); + pv_float_int8_gemm_row_avx512_kernel(out, p_row, V, vscales, n, out_d); #else #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ if (ncnn::cpu_support_x86_avx2()) @@ -1992,23 +2818,23 @@ static void pv_float_int8_gemm_row_dispatch(float* out, const float* p_row, #if __AVX2__ pv_float_int8_gemm_row_avx2(out, p_row, V, vscales, n, out_d); #elif __SSE2__ - pv_float_int8_gemm_row_sse2(out, p_row, V, vscales, n, out_d); + pv_float_int8_gemm_row_sse2_kernel(out, p_row, V, vscales, n, out_d); #else - pv_float_int8_gemm_row_scalar(out, p_row, V, vscales, n, out_d); + 
pv_float_int8_gemm_row_scalar_kernel(out, p_row, V, vscales, n, out_d); #endif #endif } // ------------------- PV Float×Int8 FMA Block (online softmax) ------------------- -static inline void pv_float_int8_fma_block_scalar(float* out, float p_invscale, const signed char* v, int len) +static inline void pv_float_int8_fma_block_scalar_kernel(float* out, float p_invscale, const signed char* v, int len) { for (int k = 0; k < len; k++) out[k] += p_invscale * v[k]; } #if __SSE2__ -static inline void pv_float_int8_fma_block_sse2(float* out, float p_invscale, const signed char* v, int len) +static inline void pv_float_int8_fma_block_sse2_kernel(float* out, float p_invscale, const signed char* v, int len) { __m128 pvec = _mm_set1_ps(p_invscale); int k = 0; @@ -2025,7 +2851,7 @@ static inline void pv_float_int8_fma_block_sse2(float* out, float p_invscale, co #endif // __SSE2__ #if __AVX2__ -inline void __attribute__((noinline)) pv_float_int8_fma_block_avx2(float* out, float p_invscale, const signed char* v, int len) +inline void pv_float_int8_fma_block_avx2_kernel(float* out, float p_invscale, const signed char* v, int len) { __m256 pvec = _mm256_set1_ps(p_invscale); int k = 0; @@ -2050,7 +2876,7 @@ inline void __attribute__((noinline)) pv_float_int8_fma_block_avx2(float* out, f #endif // __AVX2__ #if __AVX512F__ -static inline void pv_float_int8_fma_block_avx512(float* out, float p_invscale, const signed char* v, int len) +static inline void pv_float_int8_fma_block_avx512_kernel(float* out, float p_invscale, const signed char* v, int len) { __m512 pvec = _mm512_set1_ps(p_invscale); int k = 0; @@ -2074,10 +2900,10 @@ static inline void pv_float_int8_fma_block_avx512(float* out, float p_invscale, } #endif // __AVX512F__ -static void pv_float_int8_fma_block_dispatch(float* out, float p_invscale, const signed char* v, int len) +static void pv_float_int8_fma_block(float* out, float p_invscale, const signed char* v, int len) { #if __AVX512F__ - pv_float_int8_fma_block_avx512(out, 
p_invscale, v, len); + pv_float_int8_fma_block_avx512_kernel(out, p_invscale, v, len); #else #if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ if (ncnn::cpu_support_x86_avx2()) @@ -2089,15 +2915,15 @@ static void pv_float_int8_fma_block_dispatch(float* out, float p_invscale, const #if __AVX2__ pv_float_int8_fma_block_avx2(out, p_invscale, v, len); #elif __SSE2__ - pv_float_int8_fma_block_sse2(out, p_invscale, v, len); + pv_float_int8_fma_block_sse2_kernel(out, p_invscale, v, len); #else - pv_float_int8_fma_block_scalar(out, p_invscale, v, len); + pv_float_int8_fma_block_scalar_kernel(out, p_invscale, v, len); #endif #endif } #if __AVX2__ -inline void __attribute__((noinline)) pv_float_int8_gemm_tile_avx2(float* O, const float* P, +inline void pv_float_int8_gemm_tile_avx2_kernel(float* O, const float* P, const signed char* V, const float* vscales, int block_m, int block_n, int out_embed_dim) { diff --git a/src/layer/x86/sdpa_x86_xop.cpp b/src/layer/x86/sdpa_x86_xop.cpp new file mode 100644 index 000000000000..9f0216a47b64 --- /dev/null +++ b/src/layer/x86/sdpa_x86_xop.cpp @@ -0,0 +1,56 @@ +// Copyright 2025 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#endif // __AVX__ +#endif // __SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_int8.h" + +#if __XOP__ + +void decode_qk_dot_int8_xop(float* s, const signed char* q, const signed char* K, const float* qscales, const float* kscales, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_int8_xop_kernel(s, q, K, qscales, kscales, n_start, block_n, d, scale); +} + +void qk_int8_gemm_row_xop(float* s_row, const signed char* q_row, const signed char* K, float qscale, const float* kscales, int n, int d, float scale) +{ + const int num_blocks = (d + 31) / 32; + int j = 0; + for (; j < n; j++) + { + const signed char* kptr = K + j * d; + int 
sum = 0; + for (int b = 0; b < num_blocks; b++) + { + int off = b * 32; + int len = std::min(32, d - off); + if (len <= 0) continue; + sum += qk_int8_dot_block_xop_kernel(q_row + off, kptr + off, len); + } + s_row[j] = (float)sum * qscale * kscales[j] * scale; + } +} + +void qk_int8_gemm_tiled_xop(float* S, const signed char* Q, const signed char* K, const float* qscales, const float* kscales, int m, int n, int d, float scale) +{ + for (int i = 0; i < m; i++) + { + qk_int8_gemm_row_xop(S + i * n, Q + i * d, K, qscales[i], kscales, n, d, scale); + } +} + +#endif // __XOP__ + +} // namespace ncnn From ccfb1c828782cd54afd9cb05e591e3ab871a898e Mon Sep 17 00:00:00 2001 From: futz12 <56149058+futz12@users.noreply.github.com> Date: Sun, 3 May 2026 13:49:17 +0000 Subject: [PATCH 28/53] apply code-format changes --- src/layer/x86/sdpa_x86.cpp | 76 ++++++------- src/layer/x86/sdpa_x86_int8.h | 196 +++++++++++++++++----------------- 2 files changed, 135 insertions(+), 137 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 6f27a1d8d5c3..7e4b37874926 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3649,14 +3649,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to bool use_kv_cache = kv_cache && past_seqlen > 0; bool cache_valid = false; if (use_kv_cache - && cached_kv_seqlen == past_seqlen - && cached_num_group == num_group - && cached_embed_dim == embed_dim - && cached_out_embed_dim == out_embed_dim - && !cached_key_int8.empty() - && !cached_key_scales.empty() - && !cached_value_int8.empty() - && !cached_value_scales.empty()) + && cached_kv_seqlen == past_seqlen + && cached_num_group == num_group + && cached_embed_dim == embed_dim + && cached_out_embed_dim == out_embed_dim + && !cached_key_int8.empty() + && !cached_key_scales.empty() + && !cached_value_int8.empty() + && !cached_value_scales.empty()) { cache_valid = true; for (int g = 0; g < num_group; g++) @@ -3761,10 +3761,10 @@ int 
SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int block_n = std::min(BLOCK_N, dst_seqlen - n_start); decode_qk_dot_int8(s, q_int8, - key_int8_head.row(0), - q_scale, - key_scales_head.row(0), - n_start, block_n, embed_dim, _scale); + key_int8_head.row(0), + q_scale, + key_scales_head.row(0), + n_start, block_n, embed_dim, _scale); if (mask_ptr) { @@ -3793,9 +3793,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to l += l_add; decode_pv_gemv_int8(out, s, - value_int8_head.row(0), - value_scales_head.row(0), - n_start, block_n, out_embed_dim); + value_int8_head.row(0), + value_scales_head.row(0), + n_start, block_n, out_embed_dim); m = new_m; } @@ -3852,10 +3852,10 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int block_n = std::min(BLOCK_N, dst_seqlen - n_start); decode_qk_dot_int8(s, q_int8, - key_int8_head.row(0), - q_scale, - key_scales_head.row(0), - n_start, block_n, embed_dim, _scale); + key_int8_head.row(0), + q_scale, + key_scales_head.row(0), + n_start, block_n, embed_dim, _scale); if (mask_ptr) { @@ -3884,9 +3884,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to l += l_add; decode_pv_gemv_int8(out, s, - value_int8_head.row(0), - value_scales_head.row(0), - n_start, block_n, out_embed_dim); + value_int8_head.row(0), + value_scales_head.row(0), + n_start, block_n, out_embed_dim); m = new_m; } @@ -3912,7 +3912,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } - // else: fall through to fp32 decode path below + // else: fall through to fp32 decode path below #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_heads; q++) @@ -3980,20 +3980,20 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (block_m == 1) { qk_int8_gemm_row(s_vec_ptr, - q_int8_tile_head.row(0), - key_int8_head.row(n_start), - q_scales_tile_head.row(0)[0], - key_scales_head.row(n_start), - block_n, 
embed_dim, _scale); + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0)[0], + key_scales_head.row(n_start), + block_n, embed_dim, _scale); } else { qk_int8_gemm_tiled(s_vec_ptr, - q_int8_tile_head.row(0), - key_int8_head.row(n_start), - q_scales_tile_head.row(0), - key_scales_head.row(n_start), - block_m, block_n, embed_dim, _scale); + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); } for (int i = 0; i < block_m; i++) @@ -4038,15 +4038,15 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (kv_cache) { pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, - value_int8_head.row(n_start), - value_scales_head.row(n_start), - block_m, block_n, out_embed_dim); + value_int8_head.row(n_start), + value_scales_head.row(n_start), + block_m, block_n, out_embed_dim); } else { pv_gemm_dispatch(o_accum_head.row(0), p_vec_ptr, - value.channel(q / num_heads_per_group).row(n_start), - block_m, block_n, out_embed_dim); + value.channel(q / num_heads_per_group).row(n_start), + block_m, block_n, out_embed_dim); } } diff --git a/src/layer/x86/sdpa_x86_int8.h b/src/layer/x86/sdpa_x86_int8.h index 9969e901c824..4b848c12d823 100644 --- a/src/layer/x86/sdpa_x86_int8.h +++ b/src/layer/x86/sdpa_x86_int8.h @@ -978,8 +978,8 @@ static inline void decode_qk_dot_int8_avx512_kernel(float* s, const signed char* { int off = b * 64; acc = _mm512_comp_dpwssd_epi32(acc, - _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(q + off))), - _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(kptr + off)))); + _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(q + off))), + _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i*)(kptr + off)))); } for (int i = num_blocks_64 * 64; i < d; i++) scalar += q[i] * kptr[i]; @@ -991,8 +991,8 @@ static inline void decode_qk_dot_int8_avx512_kernel(float* s, const signed char* #endif // 
__AVX512F__ static void decode_qk_dot_int8(float* s, const signed char* q, - const signed char* K, const float* qscales, const float* kscales, - int n_start, int block_n, int d, float scale) + const signed char* K, const float* qscales, const float* kscales, + int n_start, int block_n, int d, float scale) { #if __AVX512F__ #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ @@ -1347,8 +1347,8 @@ static inline void qk_int8_gemm_row_avx512vnni_kernel(float* s_row, #endif // __AVX512VNNI__ static void qk_int8_gemm_row(float* s_row, - const signed char* q_row, const signed char* K, float qscale, const float* kscales, - int n, int d, float scale) + const signed char* q_row, const signed char* K, float qscale, const float* kscales, + int n, int d, float scale) { #if __AVX512F__ #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ @@ -2339,9 +2339,9 @@ static inline void qk_int8_gemm_tiled_avx512vnni_kernel(float* S, #endif // __AVX512VNNI__ static void qk_int8_gemm_tiled(float* S, - const signed char* Q, const signed char* K, - const float* qscales, const float* kscales, - int m, int n, int d, float scale) + const signed char* Q, const signed char* K, + const float* qscales, const float* kscales, + int m, int n, int d, float scale) { #if __AVX512F__ #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && !__AVX512VNNI__ @@ -2439,9 +2439,9 @@ static inline void decode_pv_gemv_int8_sse2_kernel(float* out, const float* s, { float p_invscale = s[j] / vscales[(n_start + j) * num_blocks + vb]; __m128 pvec = _mm_set1_ps(p_invscale); - __m128i v8 = _mm_set_epi8(0,0,0,0,0,0,0,0,0,0,0,0, - V[(n_start+j)*out_d+k+3], V[(n_start+j)*out_d+k+2], - V[(n_start+j)*out_d+k+1], V[(n_start+j)*out_d+k+0]); + __m128i v8 = _mm_set_epi8(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + V[(n_start + j) * out_d + k + 3], V[(n_start + j) * out_d + k + 2], + V[(n_start + j) * out_d + k + 1], V[(n_start + j) * out_d + k + 0]); __m128i v32 = _mm_cvtepi8_epi32(v8); __m128 vfp = _mm_cvtepi32_ps(v32); oval = 
_mm_add_ps(oval, _mm_mul_ps(pvec, vfp)); @@ -2617,8 +2617,8 @@ static inline void decode_pv_gemv_int8_avx512_kernel(float* out, const float* s, #endif // __AVX512F__ static void decode_pv_gemv_int8(float* out, const float* s, - const signed char* V, const float* vscales, - int n_start, int block_n, int out_d) + const signed char* V, const float* vscales, + int n_start, int block_n, int out_d) { #if __AVX512F__ decode_pv_gemv_int8_avx512_kernel(out, s, V, vscales, n_start, block_n, out_d); @@ -2802,8 +2802,8 @@ static inline void pv_float_int8_gemm_row_avx512_kernel(float* out, const float* #endif // __AVX512F__ static void pv_float_int8_gemm_row(float* out, const float* p_row, - const signed char* V, const float* vscales, - int n, int out_d) + const signed char* V, const float* vscales, + int n, int out_d) { #if __AVX512F__ pv_float_int8_gemm_row_avx512_kernel(out, p_row, V, vscales, n, out_d); @@ -2924,105 +2924,105 @@ static void pv_float_int8_fma_block(float* out, float p_invscale, const signed c #if __AVX2__ inline void pv_float_int8_gemm_tile_avx2_kernel(float* O, const float* P, - const signed char* V, const float* vscales, - int block_m, int block_n, int out_embed_dim) + const signed char* V, const float* vscales, + int block_m, int block_n, int out_embed_dim) { const int v_num_blocks = (out_embed_dim + 31) / 32; -int i = 0; -for (; i + 1 < block_m; i += 2) -{ - const float* p0 = P + i * block_n; - const float* p1 = P + (i + 1) * block_n; - for (int vb = 0; vb < v_num_blocks; vb++) + int i = 0; + for (; i + 1 < block_m; i += 2) { - int k_start = vb * 32; - int k_end = k_start + 32 < out_embed_dim ? 
k_start + 32 : out_embed_dim; - int len = k_end - k_start; - if (len <= 0) continue; - if (len == 32) - { - __m256 acc0_0 = _mm256_setzero_ps(); - __m256 acc0_1 = _mm256_setzero_ps(); - __m256 acc0_2 = _mm256_setzero_ps(); - __m256 acc0_3 = _mm256_setzero_ps(); - __m256 acc1_0 = _mm256_setzero_ps(); - __m256 acc1_1 = _mm256_setzero_ps(); - __m256 acc1_2 = _mm256_setzero_ps(); - __m256 acc1_3 = _mm256_setzero_ps(); - for (int j = 0; j < block_n; j++) + const float* p0 = P + i * block_n; + const float* p1 = P + (i + 1) * block_n; + for (int vb = 0; vb < v_num_blocks; vb++) + { + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + if (len == 32) { - float vscale = vscales[j * v_num_blocks + vb]; - __m256 vscale8 = _mm256_set1_ps(vscale); - const signed char* vptr = V + j * out_embed_dim + k_start; - __m128i v8_0 = _mm_loadl_epi64((const __m128i*)(vptr + 0)); - __m128i v8_1 = _mm_loadl_epi64((const __m128i*)(vptr + 8)); - __m128i v8_2 = _mm_loadl_epi64((const __m128i*)(vptr + 16)); - __m128i v8_3 = _mm_loadl_epi64((const __m128i*)(vptr + 24)); - __m256 vval_0 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_0)), vscale8); - __m256 vval_1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_1)), vscale8); - __m256 vval_2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_2)), vscale8); - __m256 vval_3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_3)), vscale8); - __m256 pvec0 = _mm256_set1_ps(p0[j]); - __m256 pvec1 = _mm256_set1_ps(p1[j]); - acc0_0 = _mm256_fmadd_ps(pvec0, vval_0, acc0_0); - acc0_1 = _mm256_fmadd_ps(pvec0, vval_1, acc0_1); - acc0_2 = _mm256_fmadd_ps(pvec0, vval_2, acc0_2); - acc0_3 = _mm256_fmadd_ps(pvec0, vval_3, acc0_3); - acc1_0 = _mm256_fmadd_ps(pvec1, vval_0, acc1_0); - acc1_1 = _mm256_fmadd_ps(pvec1, vval_1, acc1_1); - acc1_2 = _mm256_fmadd_ps(pvec1, vval_2, acc1_2); - acc1_3 = _mm256_fmadd_ps(pvec1, vval_3, 
acc1_3); + __m256 acc0_0 = _mm256_setzero_ps(); + __m256 acc0_1 = _mm256_setzero_ps(); + __m256 acc0_2 = _mm256_setzero_ps(); + __m256 acc0_3 = _mm256_setzero_ps(); + __m256 acc1_0 = _mm256_setzero_ps(); + __m256 acc1_1 = _mm256_setzero_ps(); + __m256 acc1_2 = _mm256_setzero_ps(); + __m256 acc1_3 = _mm256_setzero_ps(); + for (int j = 0; j < block_n; j++) + { + float vscale = vscales[j * v_num_blocks + vb]; + __m256 vscale8 = _mm256_set1_ps(vscale); + const signed char* vptr = V + j * out_embed_dim + k_start; + __m128i v8_0 = _mm_loadl_epi64((const __m128i*)(vptr + 0)); + __m128i v8_1 = _mm_loadl_epi64((const __m128i*)(vptr + 8)); + __m128i v8_2 = _mm_loadl_epi64((const __m128i*)(vptr + 16)); + __m128i v8_3 = _mm_loadl_epi64((const __m128i*)(vptr + 24)); + __m256 vval_0 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_0)), vscale8); + __m256 vval_1 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_1)), vscale8); + __m256 vval_2 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_2)), vscale8); + __m256 vval_3 = _mm256_mul_ps(_mm256_cvtepi32_ps(_mm256_cvtepi8_epi32(v8_3)), vscale8); + __m256 pvec0 = _mm256_set1_ps(p0[j]); + __m256 pvec1 = _mm256_set1_ps(p1[j]); + acc0_0 = _mm256_fmadd_ps(pvec0, vval_0, acc0_0); + acc0_1 = _mm256_fmadd_ps(pvec0, vval_1, acc0_1); + acc0_2 = _mm256_fmadd_ps(pvec0, vval_2, acc0_2); + acc0_3 = _mm256_fmadd_ps(pvec0, vval_3, acc0_3); + acc1_0 = _mm256_fmadd_ps(pvec1, vval_0, acc1_0); + acc1_1 = _mm256_fmadd_ps(pvec1, vval_1, acc1_1); + acc1_2 = _mm256_fmadd_ps(pvec1, vval_2, acc1_2); + acc1_3 = _mm256_fmadd_ps(pvec1, vval_3, acc1_3); + } + float* optr0 = O + i * out_embed_dim + k_start; + float* optr1 = O + (i + 1) * out_embed_dim + k_start; + _mm256_storeu_ps(optr0 + 0, _mm256_add_ps(_mm256_loadu_ps(optr0 + 0), acc0_0)); + _mm256_storeu_ps(optr0 + 8, _mm256_add_ps(_mm256_loadu_ps(optr0 + 8), acc0_1)); + _mm256_storeu_ps(optr0 + 16, _mm256_add_ps(_mm256_loadu_ps(optr0 + 16), acc0_2)); + _mm256_storeu_ps(optr0 + 
24, _mm256_add_ps(_mm256_loadu_ps(optr0 + 24), acc0_3)); + _mm256_storeu_ps(optr1 + 0, _mm256_add_ps(_mm256_loadu_ps(optr1 + 0), acc1_0)); + _mm256_storeu_ps(optr1 + 8, _mm256_add_ps(_mm256_loadu_ps(optr1 + 8), acc1_1)); + _mm256_storeu_ps(optr1 + 16, _mm256_add_ps(_mm256_loadu_ps(optr1 + 16), acc1_2)); + _mm256_storeu_ps(optr1 + 24, _mm256_add_ps(_mm256_loadu_ps(optr1 + 24), acc1_3)); } - float* optr0 = O + i * out_embed_dim + k_start; - float* optr1 = O + (i + 1) * out_embed_dim + k_start; - _mm256_storeu_ps(optr0 + 0, _mm256_add_ps(_mm256_loadu_ps(optr0 + 0), acc0_0)); - _mm256_storeu_ps(optr0 + 8, _mm256_add_ps(_mm256_loadu_ps(optr0 + 8), acc0_1)); - _mm256_storeu_ps(optr0 + 16, _mm256_add_ps(_mm256_loadu_ps(optr0 + 16), acc0_2)); - _mm256_storeu_ps(optr0 + 24, _mm256_add_ps(_mm256_loadu_ps(optr0 + 24), acc0_3)); - _mm256_storeu_ps(optr1 + 0, _mm256_add_ps(_mm256_loadu_ps(optr1 + 0), acc1_0)); - _mm256_storeu_ps(optr1 + 8, _mm256_add_ps(_mm256_loadu_ps(optr1 + 8), acc1_1)); - _mm256_storeu_ps(optr1 + 16, _mm256_add_ps(_mm256_loadu_ps(optr1 + 16), acc1_2)); - _mm256_storeu_ps(optr1 + 24, _mm256_add_ps(_mm256_loadu_ps(optr1 + 24), acc1_3)); - } - else - { - for (int j = 0; j < block_n; j++) + else { - float vscale = vscales[j * v_num_blocks + vb]; - const signed char* vptr = V + j * out_embed_dim + k_start; - for (int k = 0; k < len; k++) + for (int j = 0; j < block_n; j++) { - float vval = (float)vptr[k] * vscale; - O[i * out_embed_dim + k_start + k] += p0[j] * vval; - O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + float vscale = vscales[j * v_num_blocks + vb]; + const signed char* vptr = V + j * out_embed_dim + k_start; + for (int k = 0; k < len; k++) + { + float vval = (float)vptr[k] * vscale; + O[i * out_embed_dim + k_start + k] += p0[j] * vval; + O[(i + 1) * out_embed_dim + k_start + k] += p1[j] * vval; + } } } } } -} -for (; i < block_m; i++) -{ - const float* p_row = P + i * block_n; - for (int j = 0; j < block_n; j++) + for (; i < block_m; 
i++) { - float p = p_row[j]; - const signed char* vptr = V + j * out_embed_dim; - const float* vscales_row = vscales + j * v_num_blocks; - for (int vb = 0; vb < v_num_blocks; vb++) + const float* p_row = P + i * block_n; + for (int j = 0; j < block_n; j++) { - int k_start = vb * 32; - int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; - int len = k_end - k_start; - if (len <= 0) continue; - float vscale = vscales_row[vb]; - for (int k = 0; k < len; k++) + float p = p_row[j]; + const signed char* vptr = V + j * out_embed_dim; + const float* vscales_row = vscales + j * v_num_blocks; + for (int vb = 0; vb < v_num_blocks; vb++) { - O[i * out_embed_dim + k_start + k] += p * (float)vptr[k_start + k] * vscale; + int k_start = vb * 32; + int k_end = k_start + 32 < out_embed_dim ? k_start + 32 : out_embed_dim; + int len = k_end - k_start; + if (len <= 0) continue; + float vscale = vscales_row[vb]; + for (int k = 0; k < len; k++) + { + O[i * out_embed_dim + k_start + k] += p * (float)vptr[k_start + k] * vscale; + } } } } } -} #endif static void pv_float_int8_gemm_tile(float* O, const float* P, @@ -3324,5 +3324,3 @@ static void pv_float_int8_gemm_tile(float* O, const float* P, #if __SSE2__ && !__SSE4_1__ #undef _mm_cvtepi8_epi32 #endif - - From 458f55c7290fe8bdd511acff79c4755fdd77af07 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Sun, 3 May 2026 22:23:17 +0800 Subject: [PATCH 29/53] support perf int8 --- src/layer/x86/sdpa_x86.cpp | 229 ++- src/layer/x86/sdpa_x86.h | 6 +- src/layer/x86/sdpa_x86_avx512bf16.cpp | 42 + src/layer/x86/sdpa_x86_bf16s.h | 1935 +++++++++++++++++++++++++ 4 files changed, 2186 insertions(+), 26 deletions(-) create mode 100644 src/layer/x86/sdpa_x86_avx512bf16.cpp create mode 100644 src/layer/x86/sdpa_x86_bf16s.h diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 6f27a1d8d5c3..541155f0f90e 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -30,14 +30,12 @@ namespace 
ncnn { SDPA_x86::SDPA_x86() { #if NCNN_BF16 - support_bf16_storage = false; + support_bf16_storage = true; #endif -#if NCNN_INT8 cached_kv_seqlen = -1; cached_num_group = 0; cached_embed_dim = 0; cached_out_embed_dim = 0; -#endif } int SDPA_x86::create_pipeline(const Option& /*_opt*/) @@ -57,6 +55,10 @@ int SDPA_x86::destroy_pipeline(const Option& /*_opt*/) #include "sdpa_x86_int8.h" +#if NCNN_BF16 +#include "sdpa_x86_bf16s.h" +#endif + static inline void qk_gemm_scalar(float* S, const float* Q, const float* K, int m, int n, int d, float scale) { @@ -3619,6 +3621,31 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (top_blob.empty()) return -100; +#if NCNN_BF16 + bool use_bf16_path = opt.use_bf16_storage && query.elembits() == 16; + Mat query_fp32; + Mat attn_mask_fp32; + if (use_bf16_path) + { + cast_bfloat16_to_float32(query, query_fp32, opt); + if (query_fp32.empty()) + return -100; + if (attn_mask && !attn_mask_blob.empty()) + { + cast_bfloat16_to_float32(attn_mask_blob, attn_mask_fp32, opt); + if (attn_mask_fp32.empty()) + return -100; + } + } + const Mat& query_ref = use_bf16_path ? query_fp32 : query; + const Mat& attn_mask_ref = (use_bf16_path && attn_mask) ? 
attn_mask_fp32 : attn_mask_blob; +#else + const Mat& query_ref = query; + const Mat& attn_mask_ref = attn_mask_blob; + (void)query_ref; + (void)attn_mask_ref; +#endif + #if NCNN_INT8 bool use_int8_path = int8_scale_term; if (use_int8_path && src_seqlen == 1) @@ -3740,7 +3767,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const float* mask_ptr = nullptr; if (attn_mask) { - const Mat& maskm = attn_mask_blob; + const Mat& maskm = attn_mask_ref; Mat mask_head; if (maskm.dims == 3) { @@ -3831,7 +3858,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const float* mask_ptr = nullptr; if (attn_mask) { - const Mat& maskm = attn_mask_blob; + const Mat& maskm = attn_mask_ref; Mat mask_head; if (maskm.dims == 3) { @@ -3927,7 +3954,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat mask_head; if (attn_mask) { - const Mat& maskm = attn_mask_blob; + const Mat& maskm = attn_mask_ref; if (maskm.dims == 3) { mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); @@ -4091,6 +4118,114 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const int BLOCK_N = 128; const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; +#if NCNN_BF16 + if (use_bf16_path) + { + if (use_split_kv) + { + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) + { + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; + + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? 
dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + + const float* qptr = query_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float* p = partials.channel(q).row(chunk); + sdpa_decode_chunk_bf16s_dispatch(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, + n_start, n_end, embed_dim, out_embed_dim, _scale); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + float* outptr = top_blob_head.row(0); + sdpa_decode_reduce_bf16s_dispatch(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode_bf16s_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + } + + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } + + return 0; + } +#endif // NCNN_BF16 + if (use_split_kv) { const int num_kv_chunks = opt.num_threads; @@ -4110,7 +4245,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to int g = q / num_heads_per_group; const Mat key_head = key.channel(g); const Mat value_head = value.channel(g); - const Mat query_head = query.channel(q); + const Mat query_head = query_ref.channel(q); const float* qptr = query_head.row(0); const float* Kptr = key_head; @@ -4119,7 +4254,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const float* mask_ptr = nullptr; if (attn_mask) { - const Mat& maskm = attn_mask_blob; + const Mat& maskm = attn_mask_ref; Mat mask_head; if (maskm.dims == 3) { @@ -4158,7 +4293,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; - const Mat query_head = query.channel(q); + const Mat query_head = query_ref.channel(q); Mat top_blob_head = top_blob.channel(q); const float* qptr = query_head.row(0); @@ -4169,7 +4304,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to const float* mask_ptr = nullptr; if (attn_mask) { - const Mat& maskm = attn_mask_blob; + const Mat& maskm = attn_mask_ref; Mat mask_head; if (maskm.dims == 3) { @@ -4238,7 +4373,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to mask_stride[hq] = 0; if (attn_mask) { - const Mat& maskm = attn_mask_blob; + const Mat& maskm = attn_mask_ref; Mat mh = (maskm.dims == 3 && maskm.c > 1) ? maskm.channel(q) : (maskm.dims == 3 ? 
maskm.channel(0) : maskm); mask_data[hq] = mh; @@ -4254,7 +4389,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; - const Mat query_head = query.channel(q); + const Mat query_head = query_ref.channel(q); float* q_dst = q_batch_thread.row(hq * block_m); for (int i = 0; i < block_m; i++) { @@ -4288,10 +4423,22 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (!large_dim) { - qk_gemm_dispatch(s_ptr, - q_batch_thread.row(0), - key_head.row(n_start), - block_m * num_heads_per_group, block_n, embed_dim, _scale); +#if NCNN_BF16 + if (use_bf16_path) + { + qk_gemm_bf16s_dispatch(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } + else +#endif + { + qk_gemm_dispatch(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } for (int hq = 0; hq < num_heads_per_group; hq++) { @@ -4356,15 +4503,25 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - pv_gemm_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), - block_m * num_heads_per_group, block_n, out_embed_dim); +#if NCNN_BF16 + if (use_bf16_path) + { + pv_gemm_bf16s_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } + else +#endif + { + pv_gemm_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } } else { for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; - const Mat query_head = query.channel(q); + const Mat query_head = query_ref.channel(q); float* q_dst = q_batch_thread.row(0); for (int i = 0; i < block_m; i++) @@ -4374,10 +4531,22 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* s_head = 
s_ptr; - qk_gemm_dispatch(s_head, - q_dst, - key_head.row(n_start), - block_m, block_n, embed_dim, _scale); +#if NCNN_BF16 + if (use_bf16_path) + { + qk_gemm_bf16s_dispatch(s_head, + q_dst, + key_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } + else +#endif + { + qk_gemm_dispatch(s_head, + q_dst, + key_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } if (attn_mask && mask_data[hq]) { @@ -4437,8 +4606,18 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start), - block_m, block_n, out_embed_dim); +#if NCNN_BF16 + if (use_bf16_path) + { + pv_gemm_bf16s_dispatch(o_ptr, s_head, value_head.row(n_start), + block_m, block_n, out_embed_dim); + } + else +#endif + { + pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start), + block_m, block_n, out_embed_dim); + } } } } diff --git a/src/layer/x86/sdpa_x86.h b/src/layer/x86/sdpa_x86.h index 08ac3f3966bf..5e10432273e7 100644 --- a/src/layer/x86/sdpa_x86.h +++ b/src/layer/x86/sdpa_x86.h @@ -19,16 +19,20 @@ class SDPA_x86 : public SDPA virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; private: +#if NCNN_BF16 + mutable Mat cached_key_bf16; + mutable Mat cached_value_bf16; +#endif #if NCNN_INT8 mutable Mat cached_key_int8; mutable Mat cached_key_scales; mutable Mat cached_value_int8; mutable Mat cached_value_scales; +#endif mutable int cached_kv_seqlen; mutable int cached_num_group; mutable int cached_embed_dim; mutable int cached_out_embed_dim; -#endif }; } // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_avx512bf16.cpp b/src/layer/x86/sdpa_x86_avx512bf16.cpp new file mode 100644 index 000000000000..45002da2ec71 --- /dev/null +++ b/src/layer/x86/sdpa_x86_avx512bf16.cpp @@ -0,0 +1,42 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#include "cpu.h" +#include "mat.h" +#if __SSE2__ +#include +#if __AVX__ +#include +#endif // __AVX__ +#endif // 
__SSE2__ +#include "x86_usability.h" + +namespace ncnn { + +#include "sdpa_x86_bf16s.h" + +#if __AVX512BF16__ + +void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + decode_qk_dot_bf16s_avx512_kernel(s, q, K, n_start, block_n, d, scale); +} + +void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n_start, block_n, out_d); +} + +void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) +{ + qk_gemm_bf16s_avx512(S, Q, K, m, n, d, scale); +} + +void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + pv_gemm_bf16s_avx512(O, P, V, m, n, d); +} + +#endif // __AVX512BF16__ + +} // namespace ncnn diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h new file mode 100644 index 000000000000..f37c27591ee2 --- /dev/null +++ b/src/layer/x86/sdpa_x86_bf16s.h @@ -0,0 +1,1935 @@ +// Copyright 2026 Tencent +// SPDX-License-Identifier: BSD-3-Clause + +#ifndef SDPA_X86_BF16S_H +#define SDPA_X86_BF16S_H + +#include + +#if __SSE2__ +#include +#include "sse_mathfun.h" +#if __AVX__ +#include +#include "avx_mathfun.h" +#if __AVX512F__ +#include "avx512_mathfun.h" +#endif // __AVX512F__ +#endif // __AVX__ +#endif // __SSE2__ + +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && __AVX512F__ && !__AVX512BF16__ +void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale); +void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d); +void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale); +void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, 
int m, int n, int d); +#endif + +// --------------------------------------------------------------------------- +// decode_qk_dot_bf16s : Q(fp32) dot K(bf16) -> S(fp32) +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static inline void decode_qk_dot_bf16s_avx512_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + if (d >= 256) + { + for (; j + 7 < block_n; j += 8) + { + const unsigned short* k0 = K + (n_start + j + 0) * d; + const unsigned short* k1 = K + (n_start + j + 1) * d; + const unsigned short* k2 = K + (n_start + j + 2) * d; + const unsigned short* k3 = K + (n_start + j + 3) * d; + const unsigned short* k4 = K + (n_start + j + 4) * d; + const unsigned short* k5 = K + (n_start + j + 5) * d; + const unsigned short* k6 = K + (n_start + j + 6) * d; + const unsigned short* k7 = K + (n_start + j + 7) * d; + + if (j + 15 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 8) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 9) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 10) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 11) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 12) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 13) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 14) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 15) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + __m512 acc4 = _mm512_setzero_ps(); + __m512 acc5 = _mm512_setzero_ps(); + __m512 acc6 = _mm512_setzero_ps(); + __m512 acc7 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k2 + k))), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k3 + k))), acc3); + acc4 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k4 + k))), acc4); + acc5 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k5 + k))), acc5); + acc6 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k6 + k))), acc6); + acc7 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k7 + k))), acc7); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k2 + k)), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k3 + k)), acc3); + acc4 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k4 + k)), acc4); + acc5 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k5 + k)), acc5); + acc6 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k6 + k)), acc6); + acc7 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k7 + k)), acc7); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + s[j + 
4] = _mm512_comp_reduce_add_ps(acc4) * scale; + s[j + 5] = _mm512_comp_reduce_add_ps(acc5) * scale; + s[j + 6] = _mm512_comp_reduce_add_ps(acc6) * scale; + s[j + 7] = _mm512_comp_reduce_add_ps(acc7) * scale; + } + } + + for (; j + 3 < block_n; j += 4) + { + const unsigned short* k0 = K + (n_start + j + 0) * d; + const unsigned short* k1 = K + (n_start + j + 1) * d; + const unsigned short* k2 = K + (n_start + j + 2) * d; + const unsigned short* k3 = K + (n_start + j + 3) * d; + + if (j + 7 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 5) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 6) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 7) * d), _MM_HINT_T1); + } + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qv = _mm512_loadu_ps(q + k); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k2 + k))), acc2); + acc3 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k3 + k))), acc3); + } + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __m512 qv = _mm512_maskz_loadu_ps(mask_d, q + k); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + acc0 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)), acc0); + acc1 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)), acc1); + acc2 = _mm512_fmadd_ps(qv, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k2 + k)), acc2); + acc3 = _mm512_fmadd_ps(qv, 
bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k3 + k)), acc3); + } + + s[j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; + s[j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale; + s[j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale; + s[j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale; + } + + for (; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + __m512 acc = _mm512_setzero_ps(); + int k = 0; + const unsigned short* kptr = K + (n_start + j) * d; + for (; k + 15 < d; k += 16) + acc = _mm512_fmadd_ps(_mm512_loadu_ps(q + k), bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))), acc); + if (k < d) + { + __mmask16 mask_d = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + acc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask_d, q + k), bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)), acc); + } + s[j] = _mm512_comp_reduce_add_ps(acc) * scale; + } +} +#endif // __AVX512F__ + +#if __AVX__ +static inline void decode_qk_dot_bf16s_avx_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + const unsigned short* k0 = K + (n_start + j + 0) * d; + const unsigned short* k1 = K + (n_start + j + 1) * d; + + if (j + 5 < block_n) + { + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (n_start + j + 3) * d), _MM_HINT_T1); + } + + __m256 acc0 = _mm256_setzero_ps(); + __m256 acc1 = _mm256_setzero_ps(); + + int k = 0; + for (; k + 7 < d; k += 8) + { + __m256 qv = _mm256_loadu_ps(q + k); + acc0 = _mm256_comp_fmadd_ps(qv, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k0 + k))), acc0); + acc1 = _mm256_comp_fmadd_ps(qv, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k1 + k))), acc1); + } + + float sum0 = _mm256_reduce_add_ps(acc0); + float sum1 = _mm256_reduce_add_ps(acc1); + + for (; k < 
d; k++) + { + sum0 += q[k] * bfloat16_to_float32(k0[k]); + sum1 += q[k] * bfloat16_to_float32(k1[k]); + } + + s[j + 0] = sum0 * scale; + s[j + 1] = sum1 * scale; + } + + for (; j < block_n; j++) + { + if (j + 2 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 2) * d), _MM_HINT_T1); + + const unsigned short* kptr = K + (n_start + j) * d; + __m256 acc = _mm256_setzero_ps(); + int k = 0; + for (; k + 7 < d; k += 8) + acc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(q + k), bfloat2float_avx(_mm_loadu_si128((const __m128i*)(kptr + k))), acc); + float sum = _mm256_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * bfloat16_to_float32(kptr[k]); + s[j] = sum * scale; + } +} +#endif // __AVX__ + +#if __SSE2__ +static inline void decode_qk_dot_bf16s_sse2_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + for (int j = 0; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(K + (n_start + j + 4) * d), _MM_HINT_T1); + + __m128 acc = _mm_setzero_ps(); + int k = 0; + const unsigned short* kptr = K + (n_start + j) * d; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(q + k), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr + k))), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += q[k] * bfloat16_to_float32(kptr[k]); + s[j] = sum * scale; + } +} +#endif // __SSE2__ + +static inline void decode_qk_dot_bf16s_scalar_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ + for (int j = 0; j < block_n; j++) + { + float sum = 0.f; + const unsigned short* kptr = K + (n_start + j) * d; + for (int k = 0; k < d; k++) + sum += q[k] * bfloat16_to_float32(kptr[k]); + s[j] = sum * scale; + } +} + + +// --------------------------------------------------------------------------- +// decode_pv_gemv_bf16s : S(fp32) gemv V(bf16) -> out(fp32) +// 
--------------------------------------------------------------------------- + +#if __AVX512F__ +static inline void decode_pv_gemv_bf16s_avx512_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + int j = 0; + for (; j + 1 < block_n; j += 2) + { + if (j + 6 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 6) * out_d), _MM_HINT_T1); + + __m512 pvec0 = _mm512_set1_ps(s[j]); + __m512 pvec1 = _mm512_set1_ps(s[j + 1]); + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v00 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + __m512 v01 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k + 16))); + __m512 v10 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j + 1) * out_d + k))); + __m512 v11 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j + 1) * out_d + k + 16))); + oval0 = _mm512_fmadd_ps(pvec0, v00, oval0); + oval1 = _mm512_fmadd_ps(pvec0, v01, oval1); + oval0 = _mm512_fmadd_ps(pvec1, v10, oval0); + oval1 = _mm512_fmadd_ps(pvec1, v11, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 v0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + __m512 v1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j + 1) * out_d + k))); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_storeu_ps(out + k, oval); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 v0 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, 
V + (n_start + j) * out_d + k)); + __m512 v1 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, V + (n_start + j + 1) * out_d + k)); + oval = _mm512_fmadd_ps(pvec0, v0, oval); + oval = _mm512_fmadd_ps(pvec1, v1, oval); + _mm512_mask_storeu_ps(out + k, mask_d, oval); + } + } + for (; j < block_n; j++) + { + __m512 pvec512 = _mm512_set1_ps(s[j]); + int k = 0; + for (; k + 31 < out_d; k += 32) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + __m512 oval1 = _mm512_loadu_ps(out + k + 16); + __m512 v0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + __m512 v1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k + 16))); + oval0 = _mm512_fmadd_ps(pvec512, v0, oval0); + oval1 = _mm512_fmadd_ps(pvec512, v1, oval1); + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval = _mm512_loadu_ps(out + k); + __m512 vval = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + (n_start + j) * out_d + k))); + _mm512_storeu_ps(out + k, _mm512_fmadd_ps(pvec512, vval, oval)); + k += 16; + } + if (k < out_d) + { + __mmask16 mask_d = (__mmask16)((1u << (out_d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (out_d - k)) - 1); + __m512 oval = _mm512_maskz_loadu_ps(mask_d, out + k); + __m512 vval = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, V + (n_start + j) * out_d + k)); + _mm512_mask_storeu_ps(out + k, mask_d, _mm512_fmadd_ps(pvec512, vval, oval)); + } + } +} +#endif // __AVX512F__ + +#if __AVX__ +static inline void decode_pv_gemv_bf16s_avx_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + int j = 0; + for (; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + + int k = 0; + __m256 pvec256 = _mm256_set1_ps(s[j]); + for (; k + 7 < out_d; k += 8) + { + __m256 oval = _mm256_loadu_ps(out + k); + __m256 vval = 
bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + (n_start + j) * out_d + k))); + _mm256_storeu_ps(out + k, _mm256_comp_fmadd_ps(pvec256, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * bfloat16_to_float32(V[(n_start + j) * out_d + k]); + } +} +#endif // __AVX__ + +#if __SSE2__ +static inline void decode_pv_gemv_bf16s_sse2_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + if (j + 4 < block_n) + _mm_prefetch((const char*)(V + (n_start + j + 4) * out_d), _MM_HINT_T1); + + __m128 pvec128 = _mm_set1_ps(s[j]); + int k = 0; + for (; k + 3 < out_d; k += 4) + { + __m128 oval = _mm_loadu_ps(out + k); + __m128 vval = bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + (n_start + j) * out_d + k))); + _mm_storeu_ps(out + k, _mm_comp_fmadd_ps(pvec128, vval, oval)); + } + for (; k < out_d; k++) + out[k] += s[j] * bfloat16_to_float32(V[(n_start + j) * out_d + k]); + } +} +#endif // __SSE2__ + +static inline void decode_pv_gemv_bf16s_scalar_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ + for (int j = 0; j < block_n; j++) + { + float p = s[j]; + const unsigned short* vptr = V + (n_start + j) * out_d; + for (int k = 0; k < out_d; k++) + out[k] += p * bfloat16_to_float32(vptr[k]); + } +} + + +// --------------------------------------------------------------------------- +// sdpa_decode_bf16s : full decode with bf16 K/V +// --------------------------------------------------------------------------- + +static inline void sdpa_decode_bf16s(float* out, const float* q, + const unsigned short* K, const unsigned short* V, const float* mask, + int n, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; 
+#endif + + // vec_zero + { +#if __AVX512F__ + __m512 zero512 = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, zero512); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, zero512); + } +#else + int i = 0; +#if __AVX__ + __m256 zero256 = _mm256_setzero_ps(); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, zero256); +#endif +#if __SSE2__ + __m128 zero128 = _mm_setzero_ps(); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, zero128); +#endif + for (; i < out_d; i++) + out[i] = 0.f; +#endif + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < n; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n - n_start); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, q, K, n_start, block_n, d, scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, q, K, n_start, block_n, d, scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, q, K, n_start, block_n, d, scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, q, K, n_start, block_n, d, scale); +#endif + + if (mask) + { +#if __AVX512F__ + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); + } +#elif __AVX__ + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; +#elif __SSE2__ + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n_start + j))); + for (; j < block_n; j++) + s[j] += mask[n_start + j]; +#else + 
for (int j = 0; j < block_n; j++) + s[j] += mask[n_start + j]; +#endif + } + + // tile max +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + } + float tile_m = _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float tile_m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float tile_m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#else + float tile_m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#endif + + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + // vec_scale(out, scale_factor, out_d); + { +#if __AVX512F__ + __m512 vscale512 = _mm512_set1_ps(scale_factor); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); + } +#else + int i = 0; +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(scale_factor); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); +#endif +#if __SSE2__ + __m128 vscale128 = _mm_set1_ps(scale_factor); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, 
_mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); +#endif + for (; i < out_d; i++) + out[i] *= scale_factor; +#endif + } + } + + // exp and sum +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + l += _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#else + float l_add = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#endif + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n_start, block_n, out_d); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(out, s, V, n_start, block_n, out_d); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(out, s, V, n_start, block_n, out_d); +#else + 
decode_pv_gemv_bf16s_scalar_kernel(out, s, V, n_start, block_n, out_d); +#endif + + m = new_m; + } + + float inv_l = 1.f / l; + // vec_scale(out, inv_l, out_d); + { +#if __AVX512F__ + __m512 vscale512 = _mm512_set1_ps(inv_l); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); + } +#else + int i = 0; +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(inv_l); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); +#endif +#if __SSE2__ + __m128 vscale128 = _mm_set1_ps(inv_l); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); +#endif + for (; i < out_d; i++) + out[i] *= inv_l; +#endif + } +} + + +// --------------------------------------------------------------------------- +// sdpa_decode_chunk_bf16s / sdpa_decode_reduce_bf16s +// --------------------------------------------------------------------------- + +static inline void sdpa_decode_chunk_bf16s( + float* out, float* m_out, float* l_out, + const float* q, const unsigned short* K, const unsigned short* V, const float* mask, + int n_start, int n_end, int d, int out_d, float scale) +{ + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + // vec_zero + { +#if __AVX512F__ + __m512 zero512 = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, zero512); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, zero512); + } +#else + int i = 0; +#if 
__AVX__ + __m256 zero256 = _mm256_setzero_ps(); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, zero256); +#endif +#if __SSE2__ + __m128 zero128 = _mm_setzero_ps(); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, zero128); +#endif + for (; i < out_d; i++) + out[i] = 0.f; +#endif + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n = n_start; n < n_end; n += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n_end - n); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, q, K, n, block_n, d, scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, q, K, n, block_n, d, scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, q, K, n, block_n, d, scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, q, K, n, block_n, d, scale); +#endif + + if (mask) + { +#if __AVX512F__ + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n + j))); + } +#elif __AVX__ + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n + j))); + for (; j < block_n; j++) + s[j] += mask[n + j]; +#elif __SSE2__ + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n + j))); + for (; j < block_n; j++) + s[j] += mask[n + j]; +#else + for (int j = 0; j < block_n; j++) + s[j] += mask[n + j]; +#endif + } + + // tile max +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, 
_mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); + } + float tile_m = _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float tile_m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float tile_m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#else + float tile_m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + tile_m = std::max(tile_m, s[j]); +#endif + + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + // vec_scale + { +#if __AVX512F__ + __m512 vscale512 = _mm512_set1_ps(scale_factor); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); + } +#else + int i = 0; +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(scale_factor); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); +#endif +#if __SSE2__ + __m128 vscale128 = _mm_set1_ps(scale_factor); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); +#endif + for (; i < out_d; i++) + out[i] *= scale_factor; +#endif + } + } + + // exp and sum +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = 
_mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + l += _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l_add = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + float l_add = _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#else + float l_add = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l_add += s[j]; + } + l += l_add; +#endif + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n, block_n, out_d); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(out, s, V, n, block_n, out_d); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(out, s, V, n, block_n, out_d); +#else + decode_pv_gemv_bf16s_scalar_kernel(out, s, V, n, block_n, out_d); +#endif + + m = new_m; + } + + *m_out = m; + *l_out = l; +} + +static inline void sdpa_decode_reduce_bf16s( + float* out, int out_d, + const float* partials, int num_chunks, int partial_stride) +{ + float M_final = -FLT_MAX; + float S_final = 0.f; + // vec_zero + { +#if __AVX512F__ + __m512 zero512 = _mm512_setzero_ps(); + int i = 0; + for (; i + 15 < 
out_d; i += 16) + _mm512_storeu_ps(out + i, zero512); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, zero512); + } +#else + int i = 0; +#if __AVX__ + __m256 zero256 = _mm256_setzero_ps(); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, zero256); +#endif +#if __SSE2__ + __m128 zero128 = _mm_setzero_ps(); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, zero128); +#endif + for (; i < out_d; i++) + out[i] = 0.f; +#endif + } + + for (int c = 0; c < num_chunks; c++) + { + const float* p = partials + c * partial_stride; + float M_chunk = p[0]; + float S_chunk = p[1]; + if (S_chunk == 0.f) continue; + + const float* VKQ_chunk = p + 2; + + float M_new = std::max(M_final, M_chunk); + float scale_final = expf(M_final - M_new); + float scale_chunk = expf(M_chunk - M_new); + + for (int k = 0; k < out_d; k++) + { + out[k] = out[k] * scale_final + VKQ_chunk[k] * scale_chunk; + } + + S_final = S_final * scale_final + S_chunk * scale_chunk; + M_final = M_new; + } + + if (S_final != 0.f) + { + float inv_s = 1.f / S_final; + // vec_scale + { +#if __AVX512F__ + __m512 vscale512 = _mm512_set1_ps(inv_s); + int i = 0; + for (; i + 15 < out_d; i += 16) + _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); + if (i < out_d) + { + __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); + _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); + } +#else + int i = 0; +#if __AVX__ + __m256 vscale256 = _mm256_set1_ps(inv_s); + for (; i + 7 < out_d; i += 8) + _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); +#endif +#if __SSE2__ + __m128 vscale128 = _mm_set1_ps(inv_s); + for (; i + 3 < out_d; i += 4) + _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); +#endif + for (; i < out_d; i++) + out[i] *= inv_s; +#endif + } + } +} + + +// 
--------------------------------------------------------------------------- +// qk_gemm_bf16s : Q(fp32) x K^T(bf16) -> S(fp32) [prefill] +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static void qk_gemm_bf16s_avx512(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))); + __m512 kv1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kv0 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)); + __m512 kv1 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)); + + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k 
+= 16) + { + __m512 kvec = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kvec = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kv0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))); + __m512 kv1 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kv0 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)); + __m512 kv1 = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)); + + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); + } + } + + for (int mi = 0; mi < 4; 
mi++) + { + S[(i + mi) * n + j + 0] = _mm512_comp_reduce_add_ps(acc[mi][0]) * scale; + S[(i + mi) * n + j + 1] = _mm512_comp_reduce_add_ps(acc[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 kvec = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_loadu_ps(Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + if (k < d) + { + __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); + __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1); + __m512 kvec = bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); + } + } + + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = _mm512_comp_reduce_add_ps(acc[mi]) * scale; + } + } + + for (; i < m; i++) + { + int j = 0; + for (; j + 3 < n; j += 4) + { + const float* qptr = Q + i * d; + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + const unsigned short* k2 = K + (j + 2) * d; + const unsigned short* k3 = K + (j + 3) * d; + + __m512 acc0 = _mm512_setzero_ps(); + __m512 acc1 = _mm512_setzero_ps(); + __m512 acc2 = _mm512_setzero_ps(); + __m512 acc3 = _mm512_setzero_ps(); + + int k = 0; + for (; k + 15 < d; k += 16) + { + __m512 qvec = _mm512_loadu_ps(qptr + k); + acc0 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k0 + k))), acc0); + acc1 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k1 + k))), acc1); + acc2 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k2 + k))), acc2); + acc3 = _mm512_fmadd_ps(qvec, 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(k3 + k))), acc3);
+            }
+
+            if (k < d)
+            {
+                __mmask16 mask = (__mmask16)((1u << (d - k)) - 1);
+                __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1);
+                __m512 qvec = _mm512_maskz_loadu_ps(mask, qptr + k);
+                acc0 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k0 + k)), acc0);
+                acc1 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k1 + k)), acc1);
+                acc2 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k2 + k)), acc2);
+                acc3 = _mm512_fmadd_ps(qvec, bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, k3 + k)), acc3);
+            }
+
+            S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale;
+            S[i * n + j + 1] = _mm512_comp_reduce_add_ps(acc1) * scale;
+            S[i * n + j + 2] = _mm512_comp_reduce_add_ps(acc2) * scale;
+            S[i * n + j + 3] = _mm512_comp_reduce_add_ps(acc3) * scale;
+        }
+
+        for (; j < n; j++)
+        {
+            const float* qptr = Q + i * d;
+            const unsigned short* kptr = K + j * d;
+            int k = 0;
+            __m512 vacc = _mm512_setzero_ps();
+            for (; k + 15 < d; k += 16)
+                vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(kptr + k))), vacc);
+            if (k < d)
+            {
+                __mmask16 mask = (__mmask16)((1u << (d - k)) - 1);
+                __mmask16 mask16 = (__mmask16)((1u << (d - k)) - 1);
+                vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, qptr + k), bfloat2float_avx512(_mm256_maskz_loadu_epi16(mask16, kptr + k)), vacc);
+            }
+            S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale;
+        }
+    }
+}
+#endif // __AVX512F__
+
+
+#if __AVX__
+// qk_gemm_bf16s_avx: S[i * n + j] = scale * dot(Q row i, K row j) for every
+// i in [0, m) and j in [0, n).
+//   S : fp32 output, row stride n
+//   Q : fp32 input, row stride d
+//   K : bf16 input stored as unsigned short, row stride d
+// Rows of Q are processed six at a time (two K rows per step, then a single
+// K row for odd n), and any leftover rows one at a time. In every path the
+// columns that fit whole 8-wide vectors go through vector multiply-adds and
+// the remaining d % 8 columns through a scalar tail.
+static void qk_gemm_bf16s_avx(float* S, const float* Q, const unsigned short* K,
+                              int m, int n, int d, float scale)
+{
+    int i = 0;
+    for (; i + 6 <= m; i += 6)
+    {
+        int j = 0;
+        for (; j + 2 <= n; j += 2)
+        {
+            const unsigned short* k0 = K + (j + 0) * d;
+            const unsigned short* k1 = K + (j + 1) * d;
+
+            __m256 acc[6][2];
+            for (int mi = 0; mi < 6; mi++)
+            {
+                acc[mi][0] = _mm256_setzero_ps();
+                acc[mi][1] = _mm256_setzero_ps();
+            }
+
+            int k = 0;
+            for (; k + 7 < d; k += 8)
+            {
+                __m256 kv0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k0 + k)));
+                __m256 kv1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k1 + k)));
+
+                for (int mi = 0; mi < 6; mi++)
+                {
+                    __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k);
+                    acc[mi][0] = _mm256_comp_fmadd_ps(qvec, kv0, acc[mi][0]);
+                    acc[mi][1] = _mm256_comp_fmadd_ps(qvec, kv1, acc[mi][1]);
+                }
+            }
+
+            // k is where the vector loop stopped. Each of the six rows needs
+            // the full scalar tail [k, d), so the tail uses its own index
+            // instead of advancing the shared k (which would leave nothing
+            // for rows 1..5 when d is not a multiple of 8).
+            for (int mi = 0; mi < 6; mi++)
+            {
+                float sum0 = _mm256_reduce_add_ps(acc[mi][0]);
+                float sum1 = _mm256_reduce_add_ps(acc[mi][1]);
+                for (int kk = k; kk < d; kk++)
+                {
+                    float qv = Q[(i + mi) * d + kk];
+                    sum0 += qv * bfloat16_to_float32(k0[kk]);
+                    sum1 += qv * bfloat16_to_float32(k1[kk]);
+                }
+                S[(i + mi) * n + j + 0] = sum0 * scale;
+                S[(i + mi) * n + j + 1] = sum1 * scale;
+            }
+        }
+
+        for (; j < n; j++)
+        {
+            const unsigned short* kptr = K + j * d;
+            __m256 acc[6];
+            for (int mi = 0; mi < 6; mi++)
+                acc[mi] = _mm256_setzero_ps();
+
+            int k = 0;
+            for (; k + 7 < d; k += 8)
+            {
+                __m256 kvec = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(kptr + k)));
+                for (int mi = 0; mi < 6; mi++)
+                {
+                    __m256 qvec = _mm256_loadu_ps(Q + (i + mi) * d + k);
+                    acc[mi] = _mm256_comp_fmadd_ps(qvec, kvec, acc[mi]);
+                }
+            }
+
+            // Same as above: per-row tail index so every row covers [k, d).
+            for (int mi = 0; mi < 6; mi++)
+            {
+                float sum = _mm256_reduce_add_ps(acc[mi]);
+                for (int kk = k; kk < d; kk++)
+                    sum += Q[(i + mi) * d + kk] * bfloat16_to_float32(kptr[kk]);
+                S[(i + mi) * n + j] = sum * scale;
+            }
+        }
+    }
+
+    // Leftover rows (fewer than six): one Q row at a time.
+    for (; i < m; i++)
+    {
+        int j = 0;
+        for (; j + 1 < n; j += 2)
+        {
+            const float* qptr = Q + i * d;
+            const unsigned short* k0 = K + (j + 0) * d;
+            const unsigned short* k1 = K + (j + 1) * d;
+
+            __m256 acc0 = _mm256_setzero_ps();
+            __m256 acc1 = _mm256_setzero_ps();
+
+            int k = 0;
+            for (; k + 7 < d; k += 8)
+            {
+                __m256 qvec = _mm256_loadu_ps(qptr + k);
+                acc0 = _mm256_comp_fmadd_ps(qvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k0 + k))), acc0);
+                acc1 = _mm256_comp_fmadd_ps(qvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(k1 + k))), acc1);
+            }
+
+            float sum0 = _mm256_reduce_add_ps(acc0);
+            float sum1 = _mm256_reduce_add_ps(acc1);
+            for (; k < d; k++)
+            {
+                float qv = qptr[k];
+                sum0 += qv * bfloat16_to_float32(k0[k]);
+                sum1 += qv * bfloat16_to_float32(k1[k]);
+            }
+            S[i * n + j + 0] = sum0 * scale;
+            S[i * n + j + 1] = sum1 * scale;
+        }
+
+        for (; j < n; j++)
+        {
+            const float* qptr = Q + i * d;
+            const unsigned short* kptr = K + j * d;
+            float sum = 0.f;
+            int k = 0;
+            __m256 vacc = _mm256_setzero_ps();
+            for (; k + 7 < d; k += 8)
+                vacc = _mm256_comp_fmadd_ps(_mm256_loadu_ps(qptr + k), bfloat2float_avx(_mm_loadu_si128((const __m128i*)(kptr + k))), vacc);
+            sum = _mm256_reduce_add_ps(vacc);
+            for (; k < d; k++)
+                sum += qptr[k] * bfloat16_to_float32(kptr[k]);
+            S[i * n + j] = sum * scale;
+        }
+    }
+}
+#endif // __AVX__
+
+#if __SSE2__
+static void qk_gemm_bf16s_sse2(float* S, const float* Q, const unsigned short* K,
+                               int m, int n, int d, float scale)
+{
+    for (int i = 0; i < m; i++)
+    {
+        int j = 0;
+        for (; j + 1 < n; j += 2)
+        {
+            const float* qptr = Q + i * d;
+            const unsigned short* k0 = K + (j + 0) * d;
+            const unsigned short* k1 = K + (j + 1) * d;
+
+            __m128 acc0 = _mm_setzero_ps();
+            __m128 acc1 = _mm_setzero_ps();
+            int k = 0;
+            for (; k + 3 < d; k += 4)
+            {
+                __m128 qvec = _mm_loadu_ps(qptr + k);
+                acc0 = _mm_comp_fmadd_ps(qvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(k0 + k))), acc0);
+                acc1 = _mm_comp_fmadd_ps(qvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(k1 + k))), acc1);
+            }
+            float sum0 = _mm_reduce_add_ps(acc0);
+            float sum1 = _mm_reduce_add_ps(acc1);
+            for (; k < d; k++)
+            {
+                float qv = qptr[k];
+                sum0 += qv * bfloat16_to_float32(k0[k]);
+                sum1 += qv * bfloat16_to_float32(k1[k]);
+            }
+            S[i * n + j + 0] = sum0 * scale;
+            S[i * n + j + 1] = sum1 * scale;
+        }
+
+        for (; j < n; j++)
+        {
+            const float* qptr = Q + i * d;
+            const unsigned short* kptr = K + j * d;
+            __m128 acc
= _mm_setzero_ps(); + int k = 0; + for (; k + 3 < d; k += 4) + acc = _mm_comp_fmadd_ps(_mm_loadu_ps(qptr + k), bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(kptr + k))), acc); + float sum = _mm_reduce_add_ps(acc); + for (; k < d; k++) + sum += qptr[k] * bfloat16_to_float32(kptr[k]); + S[i * n + j] = sum * scale; + } + } +} +#endif // __SSE2__ + +static void qk_gemm_bf16s_scalar(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * d; + for (int j = 0; j < n; j++) + { + const unsigned short* kptr = K + j * d; + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += qptr[k] * bfloat16_to_float32(kptr[k]); + S[i * n + j] = sum * scale; + } + } +} + + +// --------------------------------------------------------------------------- +// pv_gemm_bf16s : P(fp32) x V(bf16) -> O(fp32) [prefill] +// --------------------------------------------------------------------------- + +#if __AVX512F__ +static void pv_gemm_bf16s_avx512(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + int dd = 0; + for (; dd + 127 < d; dd += 128) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + const float* pptr[4]; + for (int mi = 0; mi < 4; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m512 acc[4][8]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + 0 * 16); + acc[mi][1] = _mm512_loadu_ps(op[mi] + 1 * 16); + acc[mi][2] = _mm512_loadu_ps(op[mi] + 2 * 16); + acc[mi][3] = _mm512_loadu_ps(op[mi] + 3 * 16); + acc[mi][4] = _mm512_loadu_ps(op[mi] + 4 * 16); + acc[mi][5] = _mm512_loadu_ps(op[mi] + 5 * 16); + acc[mi][6] = _mm512_loadu_ps(op[mi] + 6 * 16); + acc[mi][7] = _mm512_loadu_ps(op[mi] + 7 * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 v0 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 0 * 16))); + __m512 v1 = 
bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 1 * 16))); + __m512 v2 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 2 * 16))); + __m512 v3 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 3 * 16))); + __m512 v4 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 4 * 16))); + __m512 v5 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 5 * 16))); + __m512 v6 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 6 * 16))); + __m512 v7 = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 7 * 16))); + + for (int mi = 0; mi < 4; mi++) + { + __m512 pvec = _mm512_set1_ps(pptr[mi][j]); + acc[mi][0] = _mm512_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm512_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] = _mm512_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm512_fmadd_ps(pvec, v3, acc[mi][3]); + acc[mi][4] = _mm512_fmadd_ps(pvec, v4, acc[mi][4]); + acc[mi][5] = _mm512_fmadd_ps(pvec, v5, acc[mi][5]); + acc[mi][6] = _mm512_fmadd_ps(pvec, v6, acc[mi][6]); + acc[mi][7] = _mm512_fmadd_ps(pvec, v7, acc[mi][7]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + _mm512_storeu_ps(op[mi] + 0 * 16, acc[mi][0]); + _mm512_storeu_ps(op[mi] + 1 * 16, acc[mi][1]); + _mm512_storeu_ps(op[mi] + 2 * 16, acc[mi][2]); + _mm512_storeu_ps(op[mi] + 3 * 16, acc[mi][3]); + _mm512_storeu_ps(op[mi] + 4 * 16, acc[mi][4]); + _mm512_storeu_ps(op[mi] + 5 * 16, acc[mi][5]); + _mm512_storeu_ps(op[mi] + 6 * 16, acc[mi][6]); + _mm512_storeu_ps(op[mi] + 7 * 16, acc[mi][7]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc0 = _mm512_loadu_ps(optr + 0 * 16); + __m512 acc1 = _mm512_loadu_ps(optr + 1 * 16); + __m512 acc2 = _mm512_loadu_ps(optr + 2 * 16); + __m512 acc3 = _mm512_loadu_ps(optr + 3 * 16); + __m512 acc4 = _mm512_loadu_ps(optr + 4 * 16); + __m512 acc5 = 
_mm512_loadu_ps(optr + 5 * 16); + __m512 acc6 = _mm512_loadu_ps(optr + 6 * 16); + __m512 acc7 = _mm512_loadu_ps(optr + 7 * 16); + + for (int j = 0; j < n; j++) + { + __m512 pvec = _mm512_set1_ps(pptr[j]); + acc0 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 0 * 16))), acc0); + acc1 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 1 * 16))), acc1); + acc2 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 2 * 16))), acc2); + acc3 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 3 * 16))), acc3); + acc4 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 4 * 16))), acc4); + acc5 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 5 * 16))), acc5); + acc6 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 6 * 16))), acc6); + acc7 = _mm512_fmadd_ps(pvec, bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd + 7 * 16))), acc7); + } + + _mm512_storeu_ps(optr + 0 * 16, acc0); + _mm512_storeu_ps(optr + 1 * 16, acc1); + _mm512_storeu_ps(optr + 2 * 16, acc2); + _mm512_storeu_ps(optr + 3 * 16, acc3); + _mm512_storeu_ps(optr + 4 * 16, acc4); + _mm512_storeu_ps(optr + 5 * 16, acc5); + _mm512_storeu_ps(optr + 6 * 16, acc6); + _mm512_storeu_ps(optr + 7 * 16, acc7); + } + } + + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + const float* pptr[4]; + for (int mi = 0; mi < 4; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m512 vvec = bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd))); + for (int mi = 0; mi < 4; mi++) 
+ acc[mi] = _mm512_fmadd_ps(_mm512_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + for (int mi = 0; mi < 4; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m512 acc = _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), bfloat2float_avx512(_mm256_loadu_si256((const __m256i*)(V + j * d + dd))), acc); + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += pptr[j] * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } +} +#endif // __AVX512F__ + +#if __AVX__ +static void pv_gemm_bf16s_avx(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + int dd = 0; + for (; dd + 31 < d; dd += 32) + { + int i = 0; + for (; i + 2 <= m; i += 2) + { + float* op[2]; + const float* pptr[2]; + for (int mi = 0; mi < 2; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + + __m256 acc[2][4]; + for (int mi = 0; mi < 2; mi++) + { + acc[mi][0] = _mm256_loadu_ps(op[mi] + 0 * 8); + acc[mi][1] = _mm256_loadu_ps(op[mi] + 1 * 8); + acc[mi][2] = _mm256_loadu_ps(op[mi] + 2 * 8); + acc[mi][3] = _mm256_loadu_ps(op[mi] + 3 * 8); + } + + for (int j = 0; j < n; j++) + { + __m256 v0 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 0 * 8))); + __m256 v1 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 1 * 8))); + __m256 v2 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 2 * 8))); + __m256 v3 = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 3 * 8))); + + for (int mi = 0; mi < 2; mi++) + { + __m256 pvec = _mm256_set1_ps(pptr[mi][j]); + acc[mi][0] = _mm256_comp_fmadd_ps(pvec, v0, acc[mi][0]); + acc[mi][1] = _mm256_comp_fmadd_ps(pvec, v1, acc[mi][1]); + acc[mi][2] 
= _mm256_comp_fmadd_ps(pvec, v2, acc[mi][2]); + acc[mi][3] = _mm256_comp_fmadd_ps(pvec, v3, acc[mi][3]); + } + } + + for (int mi = 0; mi < 2; mi++) + { + _mm256_storeu_ps(op[mi] + 0 * 8, acc[mi][0]); + _mm256_storeu_ps(op[mi] + 1 * 8, acc[mi][1]); + _mm256_storeu_ps(op[mi] + 2 * 8, acc[mi][2]); + _mm256_storeu_ps(op[mi] + 3 * 8, acc[mi][3]); + } + } + + for (; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m256 acc0 = _mm256_loadu_ps(optr + 0 * 8); + __m256 acc1 = _mm256_loadu_ps(optr + 1 * 8); + __m256 acc2 = _mm256_loadu_ps(optr + 2 * 8); + __m256 acc3 = _mm256_loadu_ps(optr + 3 * 8); + + for (int j = 0; j < n; j++) + { + __m256 pvec = _mm256_set1_ps(pptr[j]); + acc0 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 0 * 8))), acc0); + acc1 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 1 * 8))), acc1); + acc2 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 2 * 8))), acc2); + acc3 = _mm256_comp_fmadd_ps(pvec, bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd + 3 * 8))), acc3); + } + + _mm256_storeu_ps(optr + 0 * 8, acc0); + _mm256_storeu_ps(optr + 1 * 8, acc1); + _mm256_storeu_ps(optr + 2 * 8, acc2); + _mm256_storeu_ps(optr + 3 * 8, acc3); + } + } + + for (; dd + 7 < d; dd += 8) + { + int i = 0; + for (; i + 2 <= m; i += 2) + { + float* op[2]; + const float* pptr[2]; + for (int mi = 0; mi < 2; mi++) + { + op[mi] = O + (i + mi) * d + dd; + pptr[mi] = P + (i + mi) * n; + } + __m256 acc[2]; + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m256 vvec = bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd))); + for (int mi = 0; mi < 2; mi++) + acc[mi] = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[mi][j]), vvec, acc[mi]); + } + for (int mi = 0; mi < 2; mi++) + _mm256_storeu_ps(op[mi], acc[mi]); + } + + for (; 
i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m256 acc = _mm256_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm256_comp_fmadd_ps(_mm256_set1_ps(pptr[j]), bfloat2float_avx(_mm_loadu_si128((const __m128i*)(V + j * d + dd))), acc); + _mm256_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += pptr[j] * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } +} +#endif // __AVX__ + +#if __SSE2__ +static void pv_gemm_bf16s_sse2(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + int dd = 0; + for (; dd + 15 < d; dd += 16) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m128 acc0 = _mm_loadu_ps(optr + 0 * 4); + __m128 acc1 = _mm_loadu_ps(optr + 1 * 4); + __m128 acc2 = _mm_loadu_ps(optr + 2 * 4); + __m128 acc3 = _mm_loadu_ps(optr + 3 * 4); + + for (int j = 0; j < n; j++) + { + __m128 pvec = _mm_set1_ps(pptr[j]); + acc0 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 0 * 4))), acc0); + acc1 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 1 * 4))), acc1); + acc2 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 2 * 4))), acc2); + acc3 = _mm_comp_fmadd_ps(pvec, bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd + 3 * 4))), acc3); + } + + _mm_storeu_ps(optr + 0 * 4, acc0); + _mm_storeu_ps(optr + 1 * 4, acc1); + _mm_storeu_ps(optr + 2 * 4, acc2); + _mm_storeu_ps(optr + 3 * 4, acc3); + } + } + + for (; dd + 3 < d; dd += 4) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + __m128 acc = _mm_loadu_ps(optr); + for (int j = 0; j < n; j++) + acc = _mm_comp_fmadd_ps(_mm_set1_ps(pptr[j]), 
bfloat2float_sse(_mm_loadl_epi64((const __m128i*)(V + j * d + dd))), acc); + _mm_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + const float* pptr = P + i * n; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += pptr[j] * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } +} +#endif // __SSE2__ + +static void pv_gemm_bf16s_scalar(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + for (int i = 0; i < m; i++) + { + float* optr = O + i * d; + const float* pptr = P + i * n; + for (int j = 0; j < n; j++) + { + float p = pptr[j]; + const unsigned short* vptr = V + j * d; + for (int k = 0; k < d; k++) + optr[k] += p * bfloat16_to_float32(vptr[k]); + } + } +} + + +// --------------------------------------------------------------------------- +// Dispatch wrappers +// --------------------------------------------------------------------------- + +static inline void decode_qk_dot_bf16s(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) +{ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + decode_qk_dot_bf16s_avx512bf16(s, q, K, n_start, block_n, d, scale); + return; + } +#endif + decode_qk_dot_bf16s_avx512_kernel(s, q, K, n_start, block_n, d, scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, q, K, n_start, block_n, d, scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, q, K, n_start, block_n, d, scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, q, K, n_start, block_n, d, scale); +#endif +} + +static inline void decode_pv_gemv_bf16s(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) +{ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + decode_pv_gemv_bf16s_avx512bf16(out, s, V, n_start, block_n, 
out_d); + return; + } +#endif + decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n_start, block_n, out_d); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(out, s, V, n_start, block_n, out_d); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(out, s, V, n_start, block_n, out_d); +#else + decode_pv_gemv_bf16s_scalar_kernel(out, s, V, n_start, block_n, out_d); +#endif +} + +static inline void qk_gemm_bf16s_dispatch(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) +{ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + qk_gemm_bf16s_avx512bf16(S, Q, K, m, n, d, scale); + return; + } +#endif + qk_gemm_bf16s_avx512(S, Q, K, m, n, d, scale); +#elif __AVX__ + qk_gemm_bf16s_avx(S, Q, K, m, n, d, scale); +#elif __SSE2__ + qk_gemm_bf16s_sse2(S, Q, K, m, n, d, scale); +#else + qk_gemm_bf16s_scalar(S, Q, K, m, n, d, scale); +#endif +} + +static inline void pv_gemm_bf16s_dispatch(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(O, P, V, m, n, d); + return; + } +#endif + pv_gemm_bf16s_avx512(O, P, V, m, n, d); +#elif __AVX__ + pv_gemm_bf16s_avx(O, P, V, m, n, d); +#elif __SSE2__ + pv_gemm_bf16s_sse2(O, P, V, m, n, d); +#else + pv_gemm_bf16s_scalar(O, P, V, m, n, d); +#endif +} + +static inline void sdpa_decode_bf16s_dispatch(float* out, const float* q, + const unsigned short* K, const unsigned short* V, const float* mask, + int n, int d, int out_d, float scale) +{ + sdpa_decode_bf16s(out, q, K, V, mask, n, d, out_d, scale); +} + +static inline void sdpa_decode_chunk_bf16s_dispatch( + float* out, float* m_out, float* l_out, + const float* q, const unsigned short* K, const unsigned short* V, const float* mask, + int n_start, int n_end, int d, int out_d, float scale) +{ + sdpa_decode_chunk_bf16s(out, 
m_out, l_out, q, K, V, mask, n_start, n_end, d, out_d, scale); +} + +static inline void sdpa_decode_reduce_bf16s_dispatch( + float* out, int out_d, + const float* partials, int num_chunks, int partial_stride) +{ + sdpa_decode_reduce_bf16s(out, out_d, partials, num_chunks, partial_stride); +} + +#endif // SDPA_X86_BF16S_H From 702a664b0235b27d69deb96b2e4cdcc63155ef61 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 08:02:07 +0800 Subject: [PATCH 30/53] opt bf16s --- src/layer/x86/sdpa_x86_avx512bf16.cpp | 71 +- src/layer/x86/sdpa_x86_bf16s.h | 932 +++++++++++++++++++++++++- 2 files changed, 998 insertions(+), 5 deletions(-) diff --git a/src/layer/x86/sdpa_x86_avx512bf16.cpp b/src/layer/x86/sdpa_x86_avx512bf16.cpp index 45002da2ec71..298d68bcef07 100644 --- a/src/layer/x86/sdpa_x86_avx512bf16.cpp +++ b/src/layer/x86/sdpa_x86_avx512bf16.cpp @@ -19,22 +19,85 @@ namespace ncnn { void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) { - decode_qk_dot_bf16s_avx512_kernel(s, q, K, n_start, block_n, d, scale); + decode_qk_dot_bf16s_avx512bf16_kernel(s, q, K, n_start, block_n, d, scale); } void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) { - decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n_start, block_n, out_d); + decode_pv_gemv_bf16s_avx512bf16_kernel(out, s, V, n_start, block_n, out_d); } +// Explicit instantiations for common embed_dim values +// These allow GCC to fully unroll the inner d-loops at compile time, +// matching the performance of the fp32 template-specialized path. 
+template void qk_gemm_bf16s_avx512bf16_kernel_t<64>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<128>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<256>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const unsigned short*, int, int, float); +template void qk_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int, float); + +template void pv_gemm_bf16s_avx512bf16_kernel_t<64>(float*, const float*, const unsigned short*, int, int); +template void pv_gemm_bf16s_avx512bf16_kernel_t<128>(float*, const float*, const unsigned short*, int, int); +template void pv_gemm_bf16s_avx512bf16_kernel_t<256>(float*, const float*, const unsigned short*, int, int); +template void pv_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const unsigned short*, int, int); +template void pv_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const unsigned short*, int, int); +template void pv_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int); + void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) { - qk_gemm_bf16s_avx512(S, Q, K, m, n, d, scale); + switch (d) + { + case 64: + qk_gemm_bf16s_avx512bf16_kernel_t<64>(S, Q, K, m, n, scale); + break; + case 128: + qk_gemm_bf16s_avx512bf16_kernel_t<128>(S, Q, K, m, n, scale); + break; + case 256: + qk_gemm_bf16s_avx512bf16_kernel_t<256>(S, Q, K, m, n, scale); + break; + case 512: + qk_gemm_bf16s_avx512bf16_kernel_t<512>(S, Q, K, m, n, scale); + break; + case 1024: + qk_gemm_bf16s_avx512bf16_kernel_t<1024>(S, Q, K, m, n, scale); + break; + case 4096: + 
qk_gemm_bf16s_avx512bf16_kernel_t<4096>(S, Q, K, m, n, scale);
+        break;
+    default:
+        qk_gemm_bf16s_avx512bf16_kernel(S, Q, K, m, n, d, scale);
+        break;
+    }
 }
 
 void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d)
 {
-    pv_gemm_bf16s_avx512(O, P, V, m, n, d);
+    switch (d)
+    {
+    case 64:
+        pv_gemm_bf16s_avx512bf16_kernel_t<64>(O, P, V, m, n);
+        break;
+    case 128:
+        pv_gemm_bf16s_avx512bf16_kernel_t<128>(O, P, V, m, n);
+        break;
+    case 256:
+        pv_gemm_bf16s_avx512bf16_kernel_t<256>(O, P, V, m, n);
+        break;
+    case 512:
+        pv_gemm_bf16s_avx512bf16_kernel_t<512>(O, P, V, m, n);
+        break;
+    case 1024:
+        pv_gemm_bf16s_avx512bf16_kernel_t<1024>(O, P, V, m, n);
+        break;
+    case 4096:
+        pv_gemm_bf16s_avx512bf16_kernel_t<4096>(O, P, V, m, n);
+        break;
+    default:
+        pv_gemm_bf16s_avx512bf16_kernel(O, P, V, m, n, d);
+        break;
+    }
 }
 
 #endif // __AVX512BF16__
diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h
index f37c27591ee2..b26a70706ae3 100644
--- a/src/layer/x86/sdpa_x86_bf16s.h
+++ b/src/layer/x86/sdpa_x86_bf16s.h
@@ -1825,6 +1825,949 @@ static void pv_gemm_bf16s_scalar(float* O, const float* P, const unsigned short*
     }
 }
 
+#if __AVX512F__ && __AVX512BF16__
+
+// Scaled dot products of one fp32 query row against a block of bf16 key rows,
+// computed with AVX512-BF16 dot-product instructions:
+//   s[j] = scale * sum_k q[k] * K[j * d + k]   for j in [n_start, n_start + block_n)
+// q is fp32 while K is bf16, so q is rounded to bf16 (float2bfloat_avx512) before
+// each _mm512_dpbf16_ps; loading the raw fp32 bytes as bf16 lanes would pair the
+// wrong elements and feed mantissa bits into the products.
+// Elements left over after the 32- and 16-wide steps are accumulated in scalar fp32.
+static void decode_qk_dot_bf16s_avx512bf16_kernel(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale)
+{
+    int n_end = n_start + block_n;
+    int j = n_start;
+    // two key rows per iteration reuse the same converted q vector
+    for (; j + 1 < n_end; j += 2)
+    {
+        __m512 acc0 = _mm512_setzero_ps();
+        __m512 acc1 = _mm512_setzero_ps();
+        int k = 0;
+        for (; k + 31 < d; k += 32)
+        {
+            // bf16(q[k..k+15]) in the low 256 bits, bf16(q[k+16..k+31]) in the high 256 bits
+            __m512i qv = _mm512_inserti64x4(_mm512_castsi256_si512(float2bfloat_avx512(_mm512_loadu_ps(q + k))), float2bfloat_avx512(_mm512_loadu_ps(q + k + 16)), 1);
+            __m512i kv0 = _mm512_loadu_si512((const __m512i*)(K + j * d + k));
+            __m512i kv1 = _mm512_loadu_si512((const __m512i*)(K + (j + 1) * d + k));
+            acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)qv, (__m512bh)kv0);
+            acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)qv, (__m512bh)kv1);
+        }
+        if (k + 15 < d)
+        {
+            // 16 leftover elements: bf16 q in the low half, zero upper half contributes nothing
+            __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), float2bfloat_avx512(_mm512_loadu_ps(q + k)), 0);
+            __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(K + j * d + k)), 0);
+            __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(K + (j + 1) * d + k)), 0);
+            acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)qv, (__m512bh)kv0);
+            acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)qv, (__m512bh)kv1);
+            k += 16;
+        }
+        float tail_sum0 = 0, tail_sum1 = 0;
+        for (; k < d; k++)
+        {
+            float qv = q[k];
+            tail_sum0 += qv * bfloat16_to_float32(K[j * d + k]);
+            tail_sum1 += qv * bfloat16_to_float32(K[(j + 1) * d + k]);
+        }
+        s[j] = (_mm512_comp_reduce_add_ps(acc0) + tail_sum0) * scale;
+        s[j + 1] = (_mm512_comp_reduce_add_ps(acc1) + tail_sum1) * scale;
+    }
+    // odd trailing key row
+    for (; j < n_end; j++)
+    {
+        __m512 acc = _mm512_setzero_ps();
+        int k = 0;
+        for (; k + 31 < d; k += 32)
+        {
+            // same fp32 -> bf16 packing of q as in the paired loop
+            __m512i qv = _mm512_inserti64x4(_mm512_castsi256_si512(float2bfloat_avx512(_mm512_loadu_ps(q + k))), float2bfloat_avx512(_mm512_loadu_ps(q + k + 16)), 1);
+            __m512i kv = _mm512_loadu_si512((const __m512i*)(K + j * d + k));
+            acc = _mm512_dpbf16_ps(acc, (__m512bh)qv, (__m512bh)kv);
+        }
+        if (k + 15 < d)
+        {
+            __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), float2bfloat_avx512(_mm512_loadu_ps(q + k)), 0);
+            __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(K + j * d + k)), 0);
+            acc = _mm512_dpbf16_ps(acc, (__m512bh)qv, (__m512bh)kv);
+            k += 16;
+        }
+        float tail_sum = 0;
+        for (; k < d; k++)
+            tail_sum += q[k] * bfloat16_to_float32(K[j * d + k]);
+        s[j] = (_mm512_comp_reduce_add_ps(acc) + tail_sum) * scale;
+    }
+}
+
+static void decode_pv_gemv_bf16s_avx512bf16_kernel(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d)
+{
+    int n_end = n_start + block_n;
+    int k = 0;
+    for (; k + 31 < out_d; k += 32)
+    {
+        __m512 oval0 = _mm512_loadu_ps(out + k);
+        __m512 oval1 = _mm512_loadu_ps(out + k + 16);
+        for (int j = n_start; j < n_end; j++)
+        {
+            __m512 sj = _mm512_set1_ps(s[j]);
+
__m512i v0 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(V + j * out_d + k))); + __m512i v1 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(V + j * out_d + k + 16))); + oval0 = _mm512_fmadd_ps(sj, _mm512_castsi512_ps(_mm512_slli_epi32(v0, 16)), oval0); + oval1 = _mm512_fmadd_ps(sj, _mm512_castsi512_ps(_mm512_slli_epi32(v1, 16)), oval1); + } + _mm512_storeu_ps(out + k, oval0); + _mm512_storeu_ps(out + k + 16, oval1); + } + if (k + 15 < out_d) + { + __m512 oval0 = _mm512_loadu_ps(out + k); + for (int j = n_start; j < n_end; j++) + { + __m512 sj = _mm512_set1_ps(s[j]); + __m512i v0 = _mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i*)(V + j * out_d + k))); + oval0 = _mm512_fmadd_ps(sj, _mm512_castsi512_ps(_mm512_slli_epi32(v0, 16)), oval0); + } + _mm512_storeu_ps(out + k, oval0); + k += 16; + } + for (; k < out_d; k++) + { + for (int j = n_start; j < n_end; j++) + out[k] += s[j] * bfloat16_to_float32(V[j * out_d + k]); + } +} + +static void qk_gemm_bf16s_avx512bf16_kernel(float* S, const float* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + unsigned short* q_bf16 = (unsigned short*)_mm_malloc(m * d * sizeof(unsigned short), 64); + + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * d; + unsigned short* qdst = q_bf16 + i * d; + int k = 0; + for (; k + 15 < d; k += 16) + { + __m256i t = float2bfloat_avx512(_mm512_loadu_ps(qptr + k)); + _mm256_storeu_si256((__m256i*)(qdst + k), t); + } + for (; k < d; k++) + qdst[k] = float32_to_bfloat16(qptr[k]); + } + + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = 
_mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[8][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 
0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[8] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[4][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = 
bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[4] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + if (i < m) + { + qk_gemm_bf16s_avx512(S + i * n, Q + i * d, K, m - i, n, d, scale); + } + + _mm_free(q_bf16); +} + +static void pv_gemm_bf16s_avx512bf16_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) +{ + unsigned short* p_bf16 = (unsigned short*)_mm_malloc(m * n * sizeof(unsigned short), 64); + for (int i = 0; i < m; i++) + { + const float* pptr = P + i * n; + unsigned short* pdst = 
p_bf16 + i * n; + int j = 0; + for (; j + 15 < n; j += 16) + { + __m512 pvec = _mm512_loadu_ps(pptr + j); + __m256i pbf16 = (__m256i)_mm512_cvtneps_pbh(pvec); + _mm256_storeu_si256((__m256i*)(pdst + j), pbf16); + } + for (; j < n; j++) + pdst[j] = float32_to_bfloat16(pptr[j]); + } + + int dd = 0; + for (; dd + 31 < d; dd += 32) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * d + dd; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + 0); + acc[mi][1] = _mm512_loadu_ps(op[mi] + 16); + } + + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * d + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)p_ext, (__m512bh)v0_ext); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)p_ext, (__m512bh)v1_ext); + } + } + + for (int mi = 0; mi < 4; mi++) + { + _mm512_storeu_ps(op[mi] + 0, acc[mi][0]); + _mm512_storeu_ps(op[mi] + 16, acc[mi][1]); + } + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * d + dd; + float* op1 = O + (i + 1) * d + dd; + __m512 acc00 = _mm512_loadu_ps(op0 + 0); + __m512 acc01 = _mm512_loadu_ps(op0 + 16); + __m512 acc10 = _mm512_loadu_ps(op1 + 0); + __m512 acc11 = _mm512_loadu_ps(op1 + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * d + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p0 = p_bf16[i * n + j]; + unsigned 
short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc00 = _mm512_dpbf16_ps(acc00, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc01 = _mm512_dpbf16_ps(acc01, (__m512bh)p_ext0, (__m512bh)v1_ext); + acc10 = _mm512_dpbf16_ps(acc10, (__m512bh)p_ext1, (__m512bh)v0_ext); + acc11 = _mm512_dpbf16_ps(acc11, (__m512bh)p_ext1, (__m512bh)v1_ext); + } + _mm512_storeu_ps(op0 + 0, acc00); _mm512_storeu_ps(op1 + 0, acc10); + _mm512_storeu_ps(op0 + 16, acc01); _mm512_storeu_ps(op1 + 16, acc11); + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + __m512 acc0 = _mm512_loadu_ps(optr + 0); + __m512 acc1 = _mm512_loadu_ps(optr + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * d + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext, (__m512bh)v1_ext); + } + _mm512_storeu_ps(optr + 0, acc0); + _mm512_storeu_ps(optr + 16, acc1); + } + } + + for (; dd + 15 < d; dd += 16) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * d + dd; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * d + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)p_ext, (__m512bh)v0_ext); + } + } + for (int mi = 
0; mi < 4; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * d + dd; + float* op1 = O + (i + 1) * d + dd; + __m512 acc0 = _mm512_loadu_ps(op0); + __m512 acc1 = _mm512_loadu_ps(op1); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * d + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext1, (__m512bh)v0_ext); + } + _mm512_storeu_ps(op0, acc0); + _mm512_storeu_ps(op1, acc1); + } + for (; i < m; i++) + { + float* optr = O + i * d + dd; + __m512 acc = _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * d + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc = _mm512_dpbf16_ps(acc, (__m512bh)p_ext, (__m512bh)v0_ext); + } + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < d; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * d + dd; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += bfloat16_to_float32(p_bf16[i * n + j]) * bfloat16_to_float32(V[j * d + dd]); + optr[0] = acc; + } + } + + _mm_free(p_bf16); +} + +template +static void qk_gemm_bf16s_avx512bf16_kernel_t(float* S, const float* Q, const unsigned short* K, + int m, int n, float scale) +{ + unsigned short* q_bf16 = (unsigned short*)_mm_malloc(m * D * sizeof(unsigned short), 64); + + for (int i = 0; i < m; i++) + { + const float* qptr = Q + i * D; + unsigned short* qdst = q_bf16 + i * D; + int k = 0; + for (; k + 15 < D; k += 16) + { + __m256i t = float2bfloat_avx512(_mm512_loadu_ps(qptr + k)); + 
_mm256_storeu_si256((__m256i*)(qdst + k), t); + } + for (; k < D; k++) + qdst[k] = float32_to_bfloat16(qptr[k]); + } + + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * D; + const unsigned short* k1 = K + (j + 1) * D; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < D) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[8][2] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const 
unsigned short* kptr = K + j * D; + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < D) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[8] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * D; + const unsigned short* k1 = K + (j + 1) * D; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < D) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const 
__m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[4][2] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * D; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < D; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(q_bf16 + (i + mi) * D + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < D) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(q_bf16 + (i + mi) * D + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[4] = {}; + for (; k < D; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(q_bf16[(i + mi) * D + k]); + tail_sum[mi] += qv * 
bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + if (i < m) + { + // scalar tail fallback not available in template + for (int ii = i; ii < m; ii++) + for (int j = 0; j < n; j++) + { + float sum = 0; + for (int k = 0; k < D; k++) + sum += Q[ii * D + k] * bfloat16_to_float32(K[j * D + k]); + S[ii * n + j] = sum * scale; + } + } + + _mm_free(q_bf16); +} + + +template +static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n) +{ + unsigned short* p_bf16 = (unsigned short*)_mm_malloc(m * n * sizeof(unsigned short), 64); + for (int i = 0; i < m; i++) + { + const float* pptr = P + i * n; + unsigned short* pdst = p_bf16 + i * n; + int j = 0; + for (; j + 15 < n; j += 16) + { + __m512 pvec = _mm512_loadu_ps(pptr + j); + __m256i pbf16 = (__m256i)_mm512_cvtneps_pbh(pvec); + _mm256_storeu_si256((__m256i*)(pdst + j), pbf16); + } + for (; j < n; j++) + pdst[j] = float32_to_bfloat16(pptr[j]); + } + + int dd = 0; + for (; dd + 31 < D; dd += 32) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * D + dd; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_loadu_ps(op[mi] + 0); + acc[mi][1] = _mm512_loadu_ps(op[mi] + 16); + } + + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * D + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)p_ext, (__m512bh)v0_ext); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)p_ext, 
(__m512bh)v1_ext); + } + } + + for (int mi = 0; mi < 4; mi++) + { + _mm512_storeu_ps(op[mi] + 0, acc[mi][0]); + _mm512_storeu_ps(op[mi] + 16, acc[mi][1]); + } + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * D + dd; + float* op1 = O + (i + 1) * D + dd; + __m512 acc00 = _mm512_loadu_ps(op0 + 0); + __m512 acc01 = _mm512_loadu_ps(op0 + 16); + __m512 acc10 = _mm512_loadu_ps(op1 + 0); + __m512 acc11 = _mm512_loadu_ps(op1 + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * D + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc00 = _mm512_dpbf16_ps(acc00, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc01 = _mm512_dpbf16_ps(acc01, (__m512bh)p_ext0, (__m512bh)v1_ext); + acc10 = _mm512_dpbf16_ps(acc10, (__m512bh)p_ext1, (__m512bh)v0_ext); + acc11 = _mm512_dpbf16_ps(acc11, (__m512bh)p_ext1, (__m512bh)v1_ext); + } + _mm512_storeu_ps(op0 + 0, acc00); _mm512_storeu_ps(op1 + 0, acc10); + _mm512_storeu_ps(op0 + 16, acc01); _mm512_storeu_ps(op1 + 16, acc11); + } + for (; i < m; i++) + { + float* optr = O + i * D + dd; + __m512 acc0 = _mm512_loadu_ps(optr + 0); + __m512 acc1 = _mm512_loadu_ps(optr + 16); + for (int j = 0; j < n; j++) + { + const unsigned short* vptr = V + j * D + dd; + __m256i v0 = _mm256_loadu_si256((const __m256i*)(vptr + 0)); + __m256i v1 = _mm256_loadu_si256((const __m256i*)(vptr + 16)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + __m512i v1_ext = _mm512_cvtepu16_epi32(v1); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext, (__m512bh)v0_ext); + acc1 = 
_mm512_dpbf16_ps(acc1, (__m512bh)p_ext, (__m512bh)v1_ext); + } + _mm512_storeu_ps(optr + 0, acc0); + _mm512_storeu_ps(optr + 16, acc1); + } + } + + for (; dd + 15 < D; dd += 16) + { + int i = 0; + for (; i + 4 <= m; i += 4) + { + float* op[4]; + for (int mi = 0; mi < 4; mi++) + op[mi] = O + (i + mi) * D + dd; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_loadu_ps(op[mi]); + + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + for (int mi = 0; mi < 4; mi++) + { + unsigned short p_val = p_bf16[(i + mi) * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)p_ext, (__m512bh)v0_ext); + } + } + for (int mi = 0; mi < 4; mi++) + _mm512_storeu_ps(op[mi], acc[mi]); + } + for (; i + 2 <= m; i += 2) + { + float* op0 = O + i * D + dd; + float* op1 = O + (i + 1) * D + dd; + __m512 acc0 = _mm512_loadu_ps(op0); + __m512 acc1 = _mm512_loadu_ps(op1); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p0 = p_bf16[i * n + j]; + unsigned short p1 = p_bf16[(i + 1) * n + j]; + __m512i p_ext0 = _mm512_set1_epi32((unsigned int)p0); + __m512i p_ext1 = _mm512_set1_epi32((unsigned int)p1); + acc0 = _mm512_dpbf16_ps(acc0, (__m512bh)p_ext0, (__m512bh)v0_ext); + acc1 = _mm512_dpbf16_ps(acc1, (__m512bh)p_ext1, (__m512bh)v0_ext); + } + _mm512_storeu_ps(op0, acc0); + _mm512_storeu_ps(op1, acc1); + } + for (; i < m; i++) + { + float* optr = O + i * D + dd; + __m512 acc = _mm512_loadu_ps(optr); + for (int j = 0; j < n; j++) + { + __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); + __m512i v0_ext = _mm512_cvtepu16_epi32(v0); + unsigned short p_val = p_bf16[i * n + j]; + __m512i p_ext = _mm512_set1_epi32((unsigned int)p_val); + acc = _mm512_dpbf16_ps(acc, (__m512bh)p_ext, (__m512bh)v0_ext); + 
} + _mm512_storeu_ps(optr, acc); + } + } + + for (; dd < D; dd++) + { + for (int i = 0; i < m; i++) + { + float* optr = O + i * D + dd; + float acc = optr[0]; + for (int j = 0; j < n; j++) + acc += bfloat16_to_float32(p_bf16[i * n + j]) * bfloat16_to_float32(V[j * D + dd]); + optr[0] = acc; + } + } + + _mm_free(p_bf16); +} + +#endif // __AVX512F__ && __AVX512BF16__ + + + // --------------------------------------------------------------------------- // Dispatch wrappers @@ -1931,5 +2862,4 @@ static inline void sdpa_decode_reduce_bf16s_dispatch( { sdpa_decode_reduce_bf16s(out, out_d, partials, num_chunks, partial_stride); } - #endif // SDPA_X86_BF16S_H From aab48b5857f1f4c90b1e01b02a5e8af1c0182acd Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 17:09:58 +0800 Subject: [PATCH 31/53] perf: vectorize int8 decode/prefill softmax+mask, refactor sdpa_decode helpers, add missing D values to dispatch, add FP32 group_parallel decode path - Extract decode_mask_vec / decode_max_vec / decode_exp_sum_vec shared helpers with AVX512/AVX/SSE2 vectorization, replacing scalar loops in int8 decode path - Refactor sdpa_decode and sdpa_decode_chunk to use shared helpers, eliminating ~150 lines of duplicated mask/softmax code per function - Add D=768,1536,3072 to qk_gemm_dispatch and pv_gemm_dispatch for common LLM head dimensions (LLaMA-13B, Qwen-72B, etc.) 
- Add group_parallel path to FP32 decode: when num_group < num_threads, parallelize per-head instead of per-group to improve MQA/GQA utilization --- src/layer/x86/sdpa_x86.cpp | 548 ++++++++++++++++++------------------- 1 file changed, 274 insertions(+), 274 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index def40e2d4531..24651d519cf5 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -606,6 +606,140 @@ static inline void decode_qk_dot(float* s, const float* q, const float* K, int n #endif } +static inline void decode_mask_vec(float* s, const float* mask, int block_n) +{ +#if __AVX512F__ + int j = 0; + for (; j + 15 < block_n; j += 16) + _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + j))); + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + _mm512_mask_storeu_ps(s + j, mask_n, + _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + j))); + } +#elif __AVX__ + int j = 0; + for (; j + 7 < block_n; j += 8) + _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + j))); + for (; j < block_n; j++) + s[j] += mask[j]; +#elif __SSE2__ + int j = 0; + for (; j + 3 < block_n; j += 4) + _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + j))); + for (; j < block_n; j++) + s[j] += mask[j]; +#else + for (int j = 0; j < block_n; j++) + s[j] += mask[j]; +#endif +} + +static inline float decode_max_vec(const float* s, int block_n) +{ +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s + j)); + } + return _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(-FLT_MAX); + 
int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); + float m = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + m = std::max(m, s[j]); + return m; +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(-FLT_MAX); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); + float m = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + m = std::max(m, s[j]); + return m; +#else + float m = -FLT_MAX; + for (int j = 0; j < block_n; j++) + m = std::max(m, s[j]); + return m; +#endif +} + +static inline void decode_exp_sum_vec(float* s, int block_n, float new_m, float* l_add) +{ +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(new_m); + __m512 vsum = _mm512_setzero_ps(); + int j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); + _mm512_storeu_ps(s + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); + _mm512_mask_storeu_ps(s + j, mask_n, pvec); + vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); + } + *l_add = _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(new_m); + __m256 vsum = _mm256_setzero_ps(); + int j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); + _mm256_storeu_ps(s + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + float l = _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l += s[j]; + } + *l_add = l; +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(new_m); + __m128 vsum = _mm_setzero_ps(); + int j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); + _mm_storeu_ps(s + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + float l = 
_mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l += s[j]; + } + *l_add = l; +#else + float l = 0.f; + for (int j = 0; j < block_n; j++) + { + s[j] = expf(s[j] - new_m); + l += s[j]; + } + *l_add = l; +#endif +} + static void sdpa_decode(float* out, const float* q, const float* K, const float* V, const float* mask, int n, int d, int out_d, float scale) @@ -633,68 +767,9 @@ static void sdpa_decode(float* out, const float* q, decode_qk_dot(s, q, K, n_start, block_n, d, scale); if (mask) - { -#if __AVX512F__ - int j = 0; - for (; j + 15 < block_n; j += 16) - _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - _mm512_mask_storeu_ps(s + j, mask_n, - _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); - } -#elif __AVX__ - int j = 0; - for (; j + 7 < block_n; j += 8) - _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; -#elif __SSE2__ - int j = 0; - for (; j + 3 < block_n; j += 4) - _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; -#else - for (int j = 0; j < block_n; j++) - s[j] += mask[n_start + j]; -#endif - } - -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); - } - float tile_m = _mm512_comp_reduce_max_ps(vmax); -#elif __AVX__ - __m256 vmax = _mm256_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + 
j)); - float tile_m = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#elif __SSE2__ - __m128 vmax = _mm_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); - float tile_m = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#else - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#endif + decode_mask_vec(s, mask + n_start, block_n); + float tile_m = decode_max_vec(s, block_n); float new_m = std::max(m, tile_m); if (m != new_m) { @@ -703,67 +778,9 @@ static void sdpa_decode(float* out, const float* q, vec_scale(out, scale_factor, out_d); } -#if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(new_m); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); - _mm512_storeu_ps(s + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); - _mm512_mask_storeu_ps(s + j, mask_n, pvec); - vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); - } - l += _mm512_comp_reduce_add_ps(vsum); -#elif __AVX__ - __m256 vm_new = _mm256_set1_ps(new_m); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); - _mm256_storeu_ps(s + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - float l_add = _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); l += l_add; -#elif __SSE2__ - __m128 vm_new = _mm_set1_ps(new_m); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = 
exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); - _mm_storeu_ps(s + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - float l_add = _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#else - float l_add = 0.f; - for (int j = 0; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#endif decode_pv_gemv(out, s, V, n_start, block_n, out_d); @@ -802,68 +819,9 @@ static void sdpa_decode_chunk( decode_qk_dot(s, q, K, n, block_n, d, scale); if (mask) - { -#if __AVX512F__ - int j = 0; - for (; j + 15 < block_n; j += 16) - _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n + j))); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - _mm512_mask_storeu_ps(s + j, mask_n, - _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n + j))); - } -#elif __AVX__ - int j = 0; - for (; j + 7 < block_n; j += 8) - _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n + j))); - for (; j < block_n; j++) - s[j] += mask[n + j]; -#elif __SSE2__ - int j = 0; - for (; j + 3 < block_n; j += 4) - _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n + j))); - for (; j < block_n; j++) - s[j] += mask[n + j]; -#else - for (int j = 0; j < block_n; j++) - s[j] += mask[n + j]; -#endif - } - -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); - } - float tile_m = _mm512_comp_reduce_max_ps(vmax); -#elif __AVX__ - __m256 vmax = _mm256_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); - float tile_m 
= _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#elif __SSE2__ - __m128 vmax = _mm_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); - float tile_m = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#else - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#endif + decode_mask_vec(s, mask + n, block_n); + float tile_m = decode_max_vec(s, block_n); float new_m = std::max(m, tile_m); if (m != new_m) { @@ -872,67 +830,9 @@ static void sdpa_decode_chunk( vec_scale(out, scale_factor, out_d); } -#if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(new_m); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); - _mm512_storeu_ps(s + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); - _mm512_mask_storeu_ps(s + j, mask_n, pvec); - vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); - } - l += _mm512_comp_reduce_add_ps(vsum); -#elif __AVX__ - __m256 vm_new = _mm256_set1_ps(new_m); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); - _mm256_storeu_ps(s + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - float l_add = _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); l += l_add; -#elif __SSE2__ - __m128 vm_new = _mm_set1_ps(new_m); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); - 
_mm_storeu_ps(s + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - float l_add = _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#else - float l_add = 0.f; - for (int j = 0; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#endif decode_pv_gemv(out, s, V, n, block_n, out_d); @@ -3393,6 +3293,45 @@ static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, #elif __SSE2__ qk_gemm_specialized_sse2<4096>(S, Q, K, m, n, scale); return; +#endif + } + if (d == 768) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<768>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 1536) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(S, Q, K, m, n, scale); + return; +#endif + } + if (d == 3072) + { +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(S, Q, K, m, n, scale); + return; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(S, Q, K, m, n, scale); + return; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(S, Q, K, m, n, scale); + return; #endif } @@ -3499,6 +3438,45 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, #elif __SSE2__ pv_gemm_sse2<2, 4096>(O, P, V, m, n); return; +#endif + } + if (d == 768) + { +#if __AVX512F__ + pv_gemm_avx512<2, 768>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 768>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 768>(O, P, V, m, n); + return; +#endif + } + if (d == 1536) + { +#if __AVX512F__ + pv_gemm_avx512<2, 1536>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 1536>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 
1536>(O, P, V, m, n); + return; +#endif + } + if (d == 3072) + { +#if __AVX512F__ + pv_gemm_avx512<2, 3072>(O, P, V, m, n); + return; +#elif __AVX__ + pv_gemm_avx<2, 3072>(O, P, V, m, n); + return; +#elif __SSE2__ + pv_gemm_sse2<2, 3072>(O, P, V, m, n); + return; #endif } @@ -3794,14 +3772,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to n_start, block_n, embed_dim, _scale); if (mask_ptr) - { - for (int j = 0; j < block_n; j++) - s[j] += mask_ptr[n_start + j]; - } + decode_mask_vec(s, mask_ptr + n_start, block_n); - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); + float tile_m = decode_max_vec(s, block_n); float new_m = std::max(m, tile_m); if (m != new_m) @@ -3811,12 +3784,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to vec_scale_dispatch(out, scale_factor, out_embed_dim); } - float l_add = 0.f; - for (int j = 0; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); l += l_add; decode_pv_gemv_int8(out, s, @@ -3885,14 +3854,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to n_start, block_n, embed_dim, _scale); if (mask_ptr) - { - for (int j = 0; j < block_n; j++) - s[j] += mask_ptr[n_start + j]; - } + decode_mask_vec(s, mask_ptr + n_start, block_n); - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); + float tile_m = decode_max_vec(s, block_n); float new_m = std::max(m, tile_m); if (m != new_m) @@ -3902,12 +3866,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to vec_scale_dispatch(out, scale_factor, out_embed_dim); } - float l_add = 0.f; - for (int j = 0; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); l += l_add; decode_pv_gemv_int8(out, s, @@ -4284,15 +4244,55 @@ int SDPA_x86::forward(const 
std::vector& bottom_blobs, std::vector& to else { // Decode path: fused GEMV kernel for single-query attention - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); - for (int hq = 0; hq < num_heads_per_group; hq++) + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - int q = g * num_heads_per_group + hq; + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); const Mat query_head = query_ref.channel(q); Mat top_blob_head = top_blob.channel(q); From 285dee774d0b29447660125dbdb0e75b87a25e1d Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 17:20:16 +0800 Subject: [PATCH 32/53] perf: vectorize int8 prefill softmax, use decode_mask_vec consistently, hoist Q copy in large_dim path - Apply vectorized max reduction + exp to int8 prefill path, previously scalar - Replace all remaining inline mask code in prefill with decode_mask_vec() - Move Q copy outside N-tile loop in large_dim prefill path: copies per M-tile reduce from (dst_seqlen/BLOCK_N) * num_heads to num_heads Key improvements: prefill e4096_s64_int8: -3.6% decode e128_p2048_int8: -10.4% decode e512_p1024_fp16ps: -8.1% --- src/layer/x86/sdpa_x86.cpp | 242 ++++++++++++++++++++----------------- 1 file changed, 133 insertions(+), 109 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 24651d519cf5..f57a59279264 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3989,19 +3989,40 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* p_row = (block_m == 1) ? 
p_vec_ptr : (p_vec_ptr + i * block_n); if (attn_mask) + decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s_row + j)); + if (j < block_n) { - const float* mptr = mask_head.row(m_start + i) + n_start; - for (int j = 0; j < block_n; j++) - { - s_row[j] += mptr[j]; - } + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s_row + j)); } - + float m_new = _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(m_vec[i]); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s_row + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(m_vec[i]); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s_row + j)); + float m_new = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); +#else float m_new = m_vec[i]; for (int j = 0; j < block_n; j++) - { m_new = std::max(m_new, s_row[j]); - } +#endif float scale_factor = expf(m_vec[i] - m_new); float l_new = l_vec[i] * scale_factor; @@ -4012,11 +4033,63 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to optr[k] *= scale_factor; } +#if __AVX512F__ + __m512 vm_new = _mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s_row + j), vm_new)); + _mm512_storeu_ps(p_row + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, s_row + j), vm_new)); + 
_mm512_mask_storeu_ps(p_row + j, mask, pvec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, pvec); + } + l_new += _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s_row + j), vm_new)); + _mm256_storeu_ps(p_row + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + l_new += _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(m_new); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s_row + j), vm_new)); + _mm_storeu_ps(p_row + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + l_new += _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } +#else for (int j = 0; j < block_n; j++) { p_row[j] = expf(s_row[j] - m_new); l_new += p_row[j]; } +#endif m_vec[i] = m_new; l_vec[i] = l_new; @@ -4450,35 +4523,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; float* sptr = s_head + i * block_n; - int j = 0; -#if __AVX512F__ - for (; j + 15 < block_n; j += 16) - { - _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); - } - if (j < block_n) - { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); - __m512 _m = _mm512_maskz_loadu_ps(mask, mptr + j); - _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); - j = block_n; - } -#elif __AVX__ - for (; j + 7 < block_n; j += 8) - { - _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); - } -#elif __SSE2__ - for (; j + 3 < block_n; j += 4) - { - 
_mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); - } -#endif - for (; j < block_n; j++) - { - sptr[j] += mptr[j]; - } + decode_mask_vec(sptr, mptr, block_n); } } @@ -4518,6 +4563,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } else { + // large_dim: copy Q per head once, then loop over N-tiles for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; @@ -4529,94 +4575,72 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); } - float* s_head = s_ptr; + for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) + { + int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? n_start2 + BLOCK_N : dst_seqlen; + int block_n2 = n_end2 - n_start2; + + float* s_head = s_ptr; #if NCNN_BF16 - if (use_bf16_path) - { - qk_gemm_bf16s_dispatch(s_head, - q_dst, - key_head.row(n_start), - block_m, block_n, embed_dim, _scale); - } - else + if (use_bf16_path) + { + qk_gemm_bf16s_dispatch(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else #endif - { - qk_gemm_dispatch(s_head, - q_dst, - key_head.row(n_start), - block_m, block_n, embed_dim, _scale); - } + { + qk_gemm_dispatch(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } - if (attn_mask && mask_data[hq]) - { - for (int i = 0; i < block_m; i++) + if (attn_mask && mask_data[hq]) { - const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; - float* sptr = s_head + i * block_n; - int j = 0; -#if __AVX512F__ - for (; j + 15 < block_n; j += 16) - { - _mm512_storeu_ps(sptr + j, _mm512_add_ps(_mm512_loadu_ps(sptr + j), _mm512_loadu_ps(mptr + j))); - } - if (j < block_n) + for (int i = 0; i < block_m; i++) { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 _s = _mm512_maskz_loadu_ps(mask, sptr + j); - __m512 _m = 
_mm512_maskz_loadu_ps(mask, mptr + j); - _mm512_mask_storeu_ps(sptr + j, mask, _mm512_add_ps(_s, _m)); - j = block_n; - } -#elif __AVX__ - for (; j + 7 < block_n; j += 8) - { - _mm256_storeu_ps(sptr + j, _mm256_add_ps(_mm256_loadu_ps(sptr + j), _mm256_loadu_ps(mptr + j))); - } -#elif __SSE2__ - for (; j + 3 < block_n; j += 4) - { - _mm_storeu_ps(sptr + j, _mm_add_ps(_mm_loadu_ps(sptr + j), _mm_loadu_ps(mptr + j))); - } -#endif - for (; j < block_n; j++) - { - sptr[j] += mptr[j]; + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; + float* sptr = s_head + i * block_n2; + decode_mask_vec(sptr, mptr, block_n2); } } - } - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.row(hq * block_m); + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); - float m_old[BLOCK_M]; - float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile_dispatch(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n); + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_old[i] = m_vec[i]; + } + softmax_tile_dispatch(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); - for (int i = 0; i < block_m; i++) - { - if (m_old[i] != m_vec[i]) + for (int i = 0; i < block_m; i++) { - vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + if (m_old[i] != m_vec[i]) + { + vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } } - } #if NCNN_BF16 - if (use_bf16_path) - { - pv_gemm_bf16s_dispatch(o_ptr, s_head, value_head.row(n_start), - block_m, block_n, out_embed_dim); - } - else + if (use_bf16_path) + { + pv_gemm_bf16s_dispatch(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + else #endif - { - pv_gemm_dispatch(o_ptr, s_head, 
value_head.row(n_start), - block_m, block_n, out_embed_dim); + { + pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } } } } From 2dfb56653c7676da85dd93ccc17f5b68a19b42c4 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 17:45:38 +0800 Subject: [PATCH 33/53] perf: vectorize int8 prefill scale ops, fix pv_gemm register pressure, add V prefetch - Replace int8 prefill scalar scale_factor application with vec_scale_dispatch - Replace int8 prefill scalar output normalization with inline AVX512/AVX/SSE2 - Replace int8 decode scalar output normalization with memcpy+vec_scale_dispatch - Replace int8 prefill scalar zero-init with vec_zero_dispatch - Fix pv_gemm dispatch: use D_UNROLL=128 (AVX512) / 64 (AVX) / 32 (SSE2) instead of matching full D dimension. Prevents massive register spilling for large D (VEC_PER_UNROLL was 256 for D=4096 on AVX512 - far beyond 32 ZMM regs) - Add V matrix software prefetch in pv_gemm_avx512 inner j loop for d>=512 (hardware prefetcher cannot track 16KB+ strides) --- src/layer/x86/sdpa_x86.cpp | 91 ++++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 38 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index f57a59279264..9f4745e54002 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1608,6 +1608,8 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int for (int j = 0; j < n; j++) { + if (d >= 512 && j + 4 < n) + _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); __m512 vvec[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); @@ -1636,6 +1638,8 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int for (int j = 0; j < n; j++) { + if (d >= 512 && j + 4 < n) + _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); __m512 pvec = _mm512_set1_ps(pptr[j]); 
for (int vi = 0; vi < VEC_PER_UNROLL; vi++) acc[vi] = _mm512_fmadd_ps(pvec, _mm512_loadu_ps(V + j * d + dd + vi * 16), acc[vi]); @@ -3378,104 +3382,104 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, if (d == 256) { #if __AVX512F__ - pv_gemm_avx512<2, 256>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 256); return; #elif __AVX__ - pv_gemm_avx<2, 256>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 256); return; #elif __SSE2__ - pv_gemm_sse2<2, 256>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 256); return; #endif } if (d == 512) { #if __AVX512F__ - pv_gemm_avx512<2, 512>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 512); return; #elif __AVX__ - pv_gemm_avx<2, 512>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 512); return; #elif __SSE2__ - pv_gemm_sse2<2, 512>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 512); return; #endif } if (d == 1024) { #if __AVX512F__ - pv_gemm_avx512<2, 1024>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 1024); return; #elif __AVX__ - pv_gemm_avx<2, 1024>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 1024); return; #elif __SSE2__ - pv_gemm_sse2<2, 1024>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 1024); return; #endif } if (d == 2048) { #if __AVX512F__ - pv_gemm_avx512<2, 2048>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 2048); return; #elif __AVX__ - pv_gemm_avx<2, 2048>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 2048); return; #elif __SSE2__ - pv_gemm_sse2<2, 2048>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 2048); return; #endif } if (d == 4096) { #if __AVX512F__ - pv_gemm_avx512<2, 4096>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 4096); return; #elif __AVX__ - pv_gemm_avx<2, 4096>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 4096); return; #elif __SSE2__ - pv_gemm_sse2<2, 4096>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 4096); return; #endif } if (d == 768) { #if 
__AVX512F__ - pv_gemm_avx512<2, 768>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 768); return; #elif __AVX__ - pv_gemm_avx<2, 768>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 768); return; #elif __SSE2__ - pv_gemm_sse2<2, 768>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 768); return; #endif } if (d == 1536) { #if __AVX512F__ - pv_gemm_avx512<2, 1536>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 1536); return; #elif __AVX__ - pv_gemm_avx<2, 1536>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 1536); return; #elif __SSE2__ - pv_gemm_sse2<2, 1536>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 1536); return; #endif } if (d == 3072) { #if __AVX512F__ - pv_gemm_avx512<2, 3072>(O, P, V, m, n); + pv_gemm_avx512<2, 128>(O, P, V, m, n, 3072); return; #elif __AVX__ - pv_gemm_avx<2, 3072>(O, P, V, m, n); + pv_gemm_avx<2, 64>(O, P, V, m, n, 3072); return; #elif __SSE2__ - pv_gemm_sse2<2, 3072>(O, P, V, m, n); + pv_gemm_sse2<2, 32>(O, P, V, m, n, 3072); return; #endif } @@ -3798,8 +3802,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* outptr = top_blob_head.row(0); float inv_l = 1.f / l; - for (int k = 0; k < out_embed_dim; k++) - outptr[k] = out[k] * inv_l; + memcpy(outptr, out, out_embed_dim * sizeof(float)); + vec_scale_dispatch(outptr, inv_l, out_embed_dim); } } } @@ -3880,8 +3884,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* outptr = top_blob_head.row(0); float inv_l = 1.f / l; - for (int k = 0; k < out_embed_dim; k++) - outptr[k] = out[k] * inv_l; + memcpy(outptr, out, out_embed_dim * sizeof(float)); + vec_scale_dispatch(outptr, inv_l, out_embed_dim); } } @@ -3945,10 +3949,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int i = 0; i < block_m; i++) { float* optr = o_accum_head.row(i); - for (int k = 0; k < out_embed_dim; k++) - { - optr[k] = 0.f; - } + vec_zero_dispatch(optr, out_embed_dim); } float m_vec[BLOCK_M]; @@ 
-4028,10 +4029,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float l_new = l_vec[i] * scale_factor; float* optr = o_accum_head.row(i); - for (int k = 0; k < out_embed_dim; k++) - { - optr[k] *= scale_factor; - } + vec_scale_dispatch(optr, scale_factor, out_embed_dim); #if __AVX512F__ __m512 vm_new = _mm512_set1_ps(m_new); @@ -4115,10 +4113,27 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* optr = o_accum_head.row(i); float* outptr = top_blob_head.row(m_start + i); float inv_l = 1.f / l_vec[i]; - for (int k = 0; k < out_embed_dim; k++) + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(optr + k), vinv_l)); + if (k < out_embed_dim) { - outptr[k] = optr[k] * inv_l; + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, optr + k), vinv_l)); } +#elif __AVX__ + __m256 vinv_l256 = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(optr + k), vinv_l256)); +#elif __SSE2__ + __m128 vinv_l128 = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(optr + k), vinv_l128)); +#endif + for (; k < out_embed_dim; k++) + outptr[k] = optr[k] * inv_l; } } } From 397412d51831f35657171ebd9bc9bd99a7be3786 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 18:00:08 +0800 Subject: [PATCH 34/53] perf: add K-row software prefetch in qk_gemm_specialized_tiled_avx512 for D>=512 - Prefetch next 4 K rows at start of each j+=4 block when D >= 512 - Applied to both M_BLOCK2 and M_BLOCK1 loops - D>=512 means rows are at least 2KB apart; hardware prefetcher may not track such large strides reliably --- src/layer/x86/sdpa_x86.cpp | 14 ++++++++++++++ 1 file changed, 14 
insertions(+) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 9f4745e54002..1a96ef311a00 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1292,6 +1292,13 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co int j = 0; for (; j + 4 <= n; j += 4) { + if (D >= 512 && j + 8 <= n) + { + _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 6) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 7) * D), _MM_HINT_T1); + } const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; const float* k2 = K + (j + 2) * D; @@ -1393,6 +1400,13 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co int j = 0; for (; j + 4 <= n; j += 4) { + if (D >= 512 && j + 8 <= n) + { + _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 6) * D), _MM_HINT_T1); + _mm_prefetch((const char*)(K + (j + 7) * D), _MM_HINT_T1); + } const float* k0 = K + (j + 0) * D; const float* k1 = K + (j + 1) * D; const float* k2 = K + (j + 2) * D; From f0737a864d953f8ac34399627d1cdc972b2b608b Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 18:13:30 +0800 Subject: [PATCH 35/53] perf: add group_parallel path to BF16 decode for MQA/GQA thread utilization When num_group < num_threads (e.g. 
MQA with num_group=1, num_heads=32), parallelize per-head instead of per-group to use all threads --- src/layer/x86/sdpa_x86.cpp | 52 +++++++++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 6 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 1a96ef311a00..0c2ef4b3a965 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -4240,15 +4240,55 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } else { - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); - for (int hq = 0; hq < num_heads_per_group; hq++) + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode_bf16s_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - int q = g * num_heads_per_group + hq; + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); const Mat query_head = query_ref.channel(q); Mat top_blob_head = top_blob.channel(q); From b448d0d09ab306d542bbf5def94212bda6f557a5 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 21:20:08 +0800 Subject: [PATCH 36/53] refactor: remove 4 *_dispatch thin wrappers, call underlying functions directly Deleted vec_scale_dispatch, vec_zero_dispatch, softmax_tile_dispatch, sdpa_decode_dispatch - each was a simple 1-line forward to the real function with no additional logic. Replaced all 16 call sites with direct calls to vec_scale/vec_zero/softmax_tile/sdpa_decode. 
--- src/layer/x86/sdpa_x86.cpp | 53 +++++++++++--------------------------- 1 file changed, 15 insertions(+), 38 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 0c2ef4b3a965..8fb133e6f0be 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3509,29 +3509,6 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, #endif } -static inline void vec_scale_dispatch(float* x, float s, int n) -{ - vec_scale(x, s, n); -} - -static inline void vec_zero_dispatch(float* x, int n) -{ - vec_zero(x, n); -} - -static inline void softmax_tile_dispatch(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) -{ - softmax_tile(P, S, m_vec, l_vec, scale_out, m, n); -} - -static inline void sdpa_decode_dispatch(float* out, const float* q, - const float* K, const float* V, const float* mask, - int n, int d, int out_d, float scale) -{ - sdpa_decode(out, q, K, V, mask, n, d, out_d, scale); -} - // Timing instrumentation removed int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const @@ -3758,7 +3735,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* s = s_vec.channel(get_omp_thread_num()); float* out = o_accum.channel(get_omp_thread_num()); - vec_zero_dispatch(out, out_embed_dim); + vec_zero(out, out_embed_dim); const float* mask_ptr = nullptr; if (attn_mask) @@ -3799,7 +3776,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { float scale_factor = expf(m - new_m); l *= scale_factor; - vec_scale_dispatch(out, scale_factor, out_embed_dim); + vec_scale(out, scale_factor, out_embed_dim); } float l_add; @@ -3817,7 +3794,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* outptr = top_blob_head.row(0); float inv_l = 1.f / l; memcpy(outptr, out, out_embed_dim * sizeof(float)); - vec_scale_dispatch(outptr, inv_l, out_embed_dim); + 
vec_scale(outptr, inv_l, out_embed_dim); } } } @@ -3840,7 +3817,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* s = s_vec.channel(get_omp_thread_num()); float* out = o_accum.channel(get_omp_thread_num()); - vec_zero_dispatch(out, out_embed_dim); + vec_zero(out, out_embed_dim); const float* mask_ptr = nullptr; if (attn_mask) @@ -3881,7 +3858,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { float scale_factor = expf(m - new_m); l *= scale_factor; - vec_scale_dispatch(out, scale_factor, out_embed_dim); + vec_scale(out, scale_factor, out_embed_dim); } float l_add; @@ -3899,7 +3876,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float* outptr = top_blob_head.row(0); float inv_l = 1.f / l; memcpy(outptr, out, out_embed_dim * sizeof(float)); - vec_scale_dispatch(outptr, inv_l, out_embed_dim); + vec_scale(outptr, inv_l, out_embed_dim); } } @@ -3963,7 +3940,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to for (int i = 0; i < block_m; i++) { float* optr = o_accum_head.row(i); - vec_zero_dispatch(optr, out_embed_dim); + vec_zero(optr, out_embed_dim); } float m_vec[BLOCK_M]; @@ -4043,7 +4020,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to float l_new = l_vec[i] * scale_factor; float* optr = o_accum_head.row(i); - vec_scale_dispatch(optr, scale_factor, out_embed_dim); + vec_scale(optr, scale_factor, out_embed_dim); #if __AVX512F__ __m512 vm_new = _mm512_set1_ps(m_new); @@ -4423,7 +4400,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to mask_ptr = mask_head.row(0); } - sdpa_decode_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); } } } @@ -4459,7 +4436,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to mask_ptr = mask_head.row(0); } - 
sdpa_decode_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); } } } @@ -4552,7 +4529,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } float* o_ptr = o_accum_thread.row(hq * block_m); - vec_zero_dispatch(o_ptr, out_embed_dim * block_m); + vec_zero(o_ptr, out_embed_dim * block_m); } // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group @@ -4606,13 +4583,13 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { m_old[i] = m_vec[i]; } - softmax_tile_dispatch(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n); + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n); for (int i = 0; i < block_m; i++) { if (m_old[i] != m_vec[i]) { - vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } } } @@ -4688,13 +4665,13 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { m_old[i] = m_vec[i]; } - softmax_tile_dispatch(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); for (int i = 0; i < block_m; i++) { if (m_old[i] != m_vec[i]) { - vec_scale_dispatch(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } } From 7677339589fadd3ebb4951d47db2db209a79c782 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 21:30:48 +0800 Subject: [PATCH 37/53] refactor: extract FP32/BF16 prefill logic from forward() into sdpa_forward_prefill() forward() was 1234 lines; prefill path (334 lines) is now a standalone static helper with explicit parameter list. 
forward() calls it via return sdpa_forward_prefill(query_ref, ..., use_bf16_path); No functional change. Benefits: forward() shrinks ~334 lines, prefill logic is independently testable, parameter dependencies are explicit. --- src/layer/x86/sdpa_x86.cpp | 749 +++++++++++++++++++------------------ 1 file changed, 389 insertions(+), 360 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 8fb133e6f0be..f41b04208e9a 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3511,6 +3511,325 @@ static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, // Timing instrumentation removed +static int sdpa_forward_prefill( + const Mat& query_ref, + const Mat& attn_mask_ref, + Mat& key, + Mat& value, + Mat& top_blob, + std::vector& top_blobs, + const Option& opt, + int embed_dim, + int src_seqlen, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int kv_cache, + int attn_mask, + bool use_bf16_path) +{ + const int BLOCK_M = 64; + const int BLOCK_N = 128; + Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + const bool large_dim = embed_dim > 512; + Mat q_batch(embed_dim, large_dim ? 
BLOCK_M : BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + + if (s_vec.empty() || o_accum.empty() || q_batch.empty()) + return -100; + + int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; + + // Per-head per-M-tile softmax state for cross-N-tile accumulation + Mat m_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); + Mat l_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); + + if (m_state.empty() || l_state.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int idx = 0; idx < num_group * num_m_tiles; idx++) + { + int g = idx / num_m_tiles; + int m_tile = idx % num_m_tiles; + int m_start = m_tile * BLOCK_M; + int block_m = m_start + BLOCK_M < src_seqlen ? BLOCK_M : src_seqlen - m_start; + + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); + Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); + Mat q_batch_thread = q_batch.channel(get_omp_thread_num()); + + // Pre-resolve mask pointers for all heads in this group + const float* mask_data[num_heads_per_group]; + int mask_stride[num_heads_per_group]; + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + mask_data[hq] = nullptr; + mask_stride[hq] = 0; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mh = (maskm.dims == 3 && maskm.c > 1) ? maskm.channel(q) + : (maskm.dims == 3 ? 
maskm.channel(0) : maskm); + mask_data[hq] = mh; + mask_stride[hq] = mh.w; + } + } + + Mat m_state_tile = m_state.channel(idx); + Mat l_state_tile = l_state.channel(idx); + + if (!large_dim) + { + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + float* q_dst = q_batch_thread.row(hq * block_m); + for (int i = 0; i < block_m; i++) + { + memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); + } + } + } + + // Initialize softmax state and zero output accumulator for all Q heads in this group + for (int hq = 0; hq < num_heads_per_group; hq++) + { + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } + + float* o_ptr = o_accum_thread.row(hq * block_m); + vec_zero(o_ptr, out_embed_dim * block_m); + } + + // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + float* s_ptr = s_vec_thread.row(0); + + if (!large_dim) + { +#if NCNN_BF16 + if (use_bf16_path) + { + qk_gemm_bf16s_dispatch(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } + else +#endif + { + qk_gemm_dispatch(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + float* s_head = s_ptr + hq * block_m * block_n; + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; + float* sptr = s_head + i * block_n; + decode_mask_vec(sptr, mptr, block_n); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_old[i] = m_vec[i]; + } + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n); + + for (int i = 0; i < block_m; i++) + { + if (m_old[i] != m_vec[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + } + +#if NCNN_BF16 + if (use_bf16_path) + { + pv_gemm_bf16s_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } + else +#endif + { + pv_gemm_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } + } + else + { + // large_dim: copy Q per head once, then loop over N-tiles + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + + float* q_dst = q_batch_thread.row(0); + for (int i = 0; i < block_m; i++) + { + memcpy(q_dst + i * embed_dim, 
query_head.row(m_start + i), embed_dim * sizeof(float)); + } + + for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) + { + int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? n_start2 + BLOCK_N : dst_seqlen; + int block_n2 = n_end2 - n_start2; + + float* s_head = s_ptr; + +#if NCNN_BF16 + if (use_bf16_path) + { + qk_gemm_bf16s_dispatch(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else +#endif + { + qk_gemm_dispatch(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; + float* sptr = s_head + i * block_n2; + decode_mask_vec(sptr, mptr, block_n2); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_old[i] = m_vec[i]; + } + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (m_old[i] != m_vec[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + +#if NCNN_BF16 + if (use_bf16_path) + { + pv_gemm_bf16s_dispatch(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + else +#endif + { + pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + } + } + } + } + + // Normalize all Q heads for this M tile and write back to top_blob + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + Mat top_blob_head = top_blob.channel(q); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + for (int i = 0; i < block_m; i++) + { + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f 
/ l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + { + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); + } + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, o_ptr + i * out_embed_dim + k), vinv_l)); + k = out_embed_dim; + } +#elif __AVX__ + __m256 vinv_l = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + { + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); + } +#elif __SSE2__ + __m128 vinv_l = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + { + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); + } +#endif + for (; k < out_embed_dim; k++) + { + outptr[k] = o_ptr[i * out_embed_dim + k] * inv_l; + } + } + } + } + + if (kv_cache) + { + top_blobs[1] = key; + top_blobs[2] = value; + } + + return 0; +} + int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const { Option opt = _opt; @@ -4345,406 +4664,116 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } mask_ptr = mask_head.row(0); } - - float* p = partials.channel(q).row(chunk); - sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, - n_start, n_end, embed_dim, out_embed_dim, _scale); - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - Mat top_blob_head = top_blob.channel(q); - float* outptr = top_blob_head.row(0); - sdpa_decode_reduce(outptr, out_embed_dim, - partials.channel(q), num_kv_chunks, 2 + out_embed_dim); - } - } - else - { - // Decode path: fused GEMV kernel for single-query attention - const bool group_parallel = num_group >= opt.num_threads; - - if (group_parallel) - { - #pragma omp parallel for 
num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) - { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } - } - - if (kv_cache) - { - top_blobs[1] = key; - top_blobs[2] = value; - } - - return 0; - } - - Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - const bool large_dim = embed_dim > 512; - Mat q_batch(embed_dim, large_dim ? BLOCK_M : BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - - if (s_vec.empty() || o_accum.empty() || q_batch.empty()) - return -100; - - int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; - - // Per-head per-M-tile softmax state for cross-N-tile accumulation - Mat m_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); - Mat l_state(BLOCK_M, num_heads_per_group, num_group * num_m_tiles, 4u, opt.workspace_allocator); - - if (m_state.empty() || l_state.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int idx = 0; idx < num_group * num_m_tiles; idx++) - { - int g = idx / num_m_tiles; - int m_tile = idx % num_m_tiles; - int m_start = m_tile * BLOCK_M; - int block_m = m_start + BLOCK_M < src_seqlen ? 
BLOCK_M : src_seqlen - m_start; - - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - - Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); - Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); - Mat q_batch_thread = q_batch.channel(get_omp_thread_num()); - - // Pre-resolve mask pointers for all heads in this group - const float* mask_data[num_heads_per_group]; - int mask_stride[num_heads_per_group]; - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - mask_data[hq] = nullptr; - mask_stride[hq] = 0; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mh = (maskm.dims == 3 && maskm.c > 1) ? maskm.channel(q) - : (maskm.dims == 3 ? maskm.channel(0) : maskm); - mask_data[hq] = mh; - mask_stride[hq] = mh.w; - } - } - - Mat m_state_tile = m_state.channel(idx); - Mat l_state_tile = l_state.channel(idx); - - if (!large_dim) - { - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query_ref.channel(q); - float* q_dst = q_batch_thread.row(hq * block_m); - for (int i = 0; i < block_m; i++) - { - memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); - } + + float* p = partials.channel(q).row(chunk); + sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, + n_start, n_end, embed_dim, out_embed_dim, _scale); } - } - // Initialize softmax state and zero output accumulator for all Q heads in this group - for (int hq = 0; hq < num_heads_per_group; hq++) - { - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - for (int i = 0; i < block_m; i++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - m_vec[i] = -FLT_MAX; - l_vec[i] = 0.f; + Mat top_blob_head = top_blob.channel(q); + float* outptr = top_blob_head.row(0); + sdpa_decode_reduce(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + 
out_embed_dim); } - - float* o_ptr = o_accum_thread.row(hq * block_m); - vec_zero(o_ptr, out_embed_dim * block_m); } - - // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + else { - int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; - int block_n = n_end - n_start; - - float* s_ptr = s_vec_thread.row(0); + // Decode path: fused GEMV kernel for single-query attention + const bool group_parallel = num_group >= opt.num_threads; - if (!large_dim) + if (group_parallel) { -#if NCNN_BF16 - if (use_bf16_path) - { - qk_gemm_bf16s_dispatch(s_ptr, - q_batch_thread.row(0), - key_head.row(n_start), - block_m * num_heads_per_group, block_n, embed_dim, _scale); - } - else -#endif - { - qk_gemm_dispatch(s_ptr, - q_batch_thread.row(0), - key_head.row(n_start), - block_m * num_heads_per_group, block_n, embed_dim, _scale); - } - - for (int hq = 0; hq < num_heads_per_group; hq++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) { - float* s_head = s_ptr + hq * block_m * block_n; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); - if (attn_mask && mask_data[hq]) + for (int hq = 0; hq < num_heads_per_group; hq++) { - for (int i = 0; i < block_m; i++) - { - const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start; - float* sptr = s_head + i * block_n; - decode_mask_vec(sptr, mptr, block_n); - } - } - - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.row(hq * block_m); + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); - float m_old[BLOCK_M]; - float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n); + const float* 
qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; - for (int i = 0; i < block_m; i++) - { - if (m_old[i] != m_vec[i]) + const float* mask_ptr = nullptr; + if (attn_mask) { - vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); } - } - } -#if NCNN_BF16 - if (use_bf16_path) - { - pv_gemm_bf16s_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), - block_m * num_heads_per_group, block_n, out_embed_dim); - } - else -#endif - { - pv_gemm_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), - block_m * num_heads_per_group, block_n, out_embed_dim); + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } } } else { - // large_dim: copy Q per head once, then loop over N-tiles - for (int hq = 0; hq < num_heads_per_group; hq++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - int q = g * num_heads_per_group + hq; + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); - float* q_dst = q_batch_thread.row(0); - for (int i = 0; i < block_m; i++) - { - memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); - } + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; - for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) + const float* mask_ptr = nullptr; + if (attn_mask) { - int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? 
n_start2 + BLOCK_N : dst_seqlen; - int block_n2 = n_end2 - n_start2; - - float* s_head = s_ptr; - -#if NCNN_BF16 - if (use_bf16_path) - { - qk_gemm_bf16s_dispatch(s_head, - q_dst, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); - } - else -#endif - { - qk_gemm_dispatch(s_head, - q_dst, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); - } - - if (attn_mask && mask_data[hq]) - { - for (int i = 0; i < block_m; i++) - { - const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; - float* sptr = s_head + i * block_n2; - decode_mask_vec(sptr, mptr, block_n2); - } - } - - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.row(hq * block_m); - - float m_old[BLOCK_M]; - float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); - - for (int i = 0; i < block_m; i++) - { - if (m_old[i] != m_vec[i]) - { - vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); - } - } - -#if NCNN_BF16 - if (use_bf16_path) + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) { - pv_gemm_bf16s_dispatch(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); } else -#endif { - pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); + mask_head = maskm; } + mask_ptr = mask_head.row(0); } + + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); } } } - // Normalize all Q heads for this M tile and write back to top_blob - for (int hq = 0; hq < num_heads_per_group; hq++) + if (kv_cache) { - int q = g * num_heads_per_group + hq; - Mat top_blob_head = top_blob.channel(q); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.row(hq * block_m); - - for (int i = 0; i < block_m; i++) - { - float* outptr = top_blob_head.row(m_start + i); - float inv_l = 1.f / l_vec[i]; - int k = 0; -#if __AVX512F__ - __m512 vinv_l = _mm512_set1_ps(inv_l); - for (; k + 15 < out_embed_dim; k += 16) - { - _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); - } - if (k < out_embed_dim) - { - __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); - _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, o_ptr + i * out_embed_dim + k), vinv_l)); - k = out_embed_dim; - } -#elif __AVX__ - __m256 vinv_l = _mm256_set1_ps(inv_l); - for (; k + 7 < out_embed_dim; k += 8) - { - _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); - } -#elif __SSE2__ - __m128 vinv_l = _mm_set1_ps(inv_l); - for (; k + 3 < out_embed_dim; k += 4) - { - _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(o_ptr + i * out_embed_dim + k), vinv_l)); - } -#endif - for (; k < out_embed_dim; k++) - { - outptr[k] = o_ptr[i * out_embed_dim + k] * inv_l; - } - } + top_blobs[1] = key; + top_blobs[2] = value; } - } - if (kv_cache) - { - top_blobs[1] = key; - top_blobs[2] = value; + return 0; } - return 0; + return sdpa_forward_prefill(query_ref, attn_mask_ref, key, value, top_blob, + top_blobs, opt, embed_dim, src_seqlen, num_heads, 
+ num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, kv_cache, attn_mask, + use_bf16_path); } } // namespace ncnn From 99533524a3b705ceaf1951c1d8663c3a2583f772 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 21:33:51 +0800 Subject: [PATCH 38/53] refactor: remove duplicate BLOCK_N=128 in forward() decode path forward() already defines BLOCK_N=128 at the top; the inner if-block's redefinition was shadowing it unnecessarily. Now only 1 BLOCK_N per scope: - sdpa_decode(), sdpa_decode_chunk(), sdpa_forward_prefill() each have their own - forward() has a single shared BLOCK_N used by all internal paths --- src/layer/x86/sdpa_x86.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index f41b04208e9a..98ae6b857f3f 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -4473,7 +4473,6 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to // FP32 optimized path using tiled GEMM + online softmax if (src_seqlen == 1) { - const int BLOCK_N = 128; const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; #if NCNN_BF16 From 3d065c279c46d97bd942d791244a7e7381ed4d86 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 21:36:24 +0800 Subject: [PATCH 39/53] refactor: sdpa_decode now calls sdpa_decode_chunk, eliminating ~45 lines of duplicate code sdpa_decode simply delegates to sdpa_decode_chunk(out, &m, &l, q, K, V, mask, 0, n, d, out_d, scale) then normalizes with 1/l. Also fixes a pre-existing unused-variable warning (qk_num_blocks). 
--- src/layer/x86/sdpa_x86.cpp | 46 +++----------------------------------- 1 file changed, 3 insertions(+), 43 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 98ae6b857f3f..e18d19047c4a 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -744,49 +744,8 @@ static void sdpa_decode(float* out, const float* q, const float* K, const float* V, const float* mask, int n, int d, int out_d, float scale) { - const int BLOCK_N = 128; -#if __AVX512F__ - __attribute__((aligned(64))) float s[BLOCK_N]; -#elif __AVX__ - __attribute__((aligned(32))) float s[BLOCK_N]; -#elif __SSE2__ - __attribute__((aligned(16))) float s[BLOCK_N]; -#else - float s[BLOCK_N]; -#endif - - vec_zero(out, out_d); - - float m = -FLT_MAX; - float l = 0.f; - - for (int n_start = 0; n_start < n; n_start += BLOCK_N) - { - int block_n = std::min(BLOCK_N, n - n_start); - - decode_qk_dot(s, q, K, n_start, block_n, d, scale); - - if (mask) - decode_mask_vec(s, mask + n_start, block_n); - - float tile_m = decode_max_vec(s, block_n); - float new_m = std::max(m, tile_m); - if (m != new_m) - { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale(out, scale_factor, out_d); - } - - float l_add; - decode_exp_sum_vec(s, block_n, new_m, &l_add); - l += l_add; - - decode_pv_gemv(out, s, V, n_start, block_n, out_d); - - m = new_m; - } - + float m, l; + sdpa_decode_chunk(out, &m, &l, q, K, V, mask, 0, n, d, out_d, scale); float inv_l = 1.f / l; vec_scale(out, inv_l, out_d); } @@ -3955,6 +3914,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (use_int8_path) { const int qk_num_blocks = (embed_dim + 31) / 32; + (void)qk_num_blocks; const int v_num_blocks = (out_embed_dim + 31) / 32; Mat key_int8(embed_dim, dst_seqlen, num_group, 1u, opt.blob_allocator); From 596884f4f1029e41b153ebddfd1b78456812c7c4 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 21:44:55 +0800 Subject: [PATCH 40/53] 
refactor: extract sdpa_int8_decode_core helper, eliminate ~60 lines of duplicate N-loop code MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The int8 decode group_parallel and per-head paths had identical N-tile inner loops (decode_qk → mask → max → online softmax → pv_gemv). Extracted as sdpa_int8_decode_core(). Also fixed: - sdpa_decode forward-declares sdpa_decode_chunk (now calls it) - sdpa_forward_prefill suppresses unused-param warning for num_heads --- src/layer/x86/sdpa_x86.cpp | 106 +++++++++++++++++-------------------- 1 file changed, 48 insertions(+), 58 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index e18d19047c4a..30ec27ef272f 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -740,6 +740,11 @@ static inline void decode_exp_sum_vec(float* s, int block_n, float new_m, float* #endif } +static void sdpa_decode_chunk( + float* out, float* m_out, float* l_out, + const float* q, const float* K, const float* V, const float* mask, + int n_start, int n_end, int d, int out_d, float scale); + static void sdpa_decode(float* out, const float* q, const float* K, const float* V, const float* mask, int n, int d, int out_d, float scale) @@ -839,6 +844,32 @@ static void sdpa_decode_reduce( } } +static inline void sdpa_int8_decode_core( + float* s, float* out, float* m, float* l, + const signed char* q_int8, const float* q_scale, + const signed char* key_int8, const float* key_scales, + const signed char* value_int8, const float* value_scales, + const float* mask_ptr, + int n_start, int block_n, int embed_dim, int out_embed_dim, float scale) +{ + decode_qk_dot_int8(s, q_int8, key_int8, q_scale, key_scales, n_start, block_n, embed_dim, scale); + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n_start, block_n); + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(*m, tile_m); + if (*m != new_m) + { + float scale_factor = expf(*m - new_m); + *l *= 
scale_factor; + vec_scale(out, scale_factor, out_embed_dim); + } + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + *l += l_add; + decode_pv_gemv_int8(out, s, value_int8, value_scales, n_start, block_n, out_embed_dim); + *m = new_m; +} + #if __AVX512F__ static void qk_gemm_avx512(float* S, const float* Q, const float* K, @@ -3490,6 +3521,7 @@ static int sdpa_forward_prefill( int attn_mask, bool use_bf16_path) { + (void)num_heads; const int BLOCK_M = 64; const int BLOCK_N = 128; Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); @@ -4039,35 +4071,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { int block_n = std::min(BLOCK_N, dst_seqlen - n_start); - decode_qk_dot_int8(s, q_int8, - key_int8_head.row(0), - q_scale, - key_scales_head.row(0), - n_start, block_n, embed_dim, _scale); - - if (mask_ptr) - decode_mask_vec(s, mask_ptr + n_start, block_n); - - float tile_m = decode_max_vec(s, block_n); - - float new_m = std::max(m, tile_m); - if (m != new_m) - { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale(out, scale_factor, out_embed_dim); - } - - float l_add; - decode_exp_sum_vec(s, block_n, new_m, &l_add); - l += l_add; - - decode_pv_gemv_int8(out, s, - value_int8_head.row(0), - value_scales_head.row(0), - n_start, block_n, out_embed_dim); - - m = new_m; + sdpa_int8_decode_core(s, out, &m, &l, + q_int8, q_scale, + key_int8_head.row(0), + key_scales_head.row(0), + value_int8_head.row(0), + value_scales_head.row(0), + mask_ptr, + n_start, block_n, embed_dim, out_embed_dim, _scale); } float* outptr = top_blob_head.row(0); @@ -4121,35 +4132,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to { int block_n = std::min(BLOCK_N, dst_seqlen - n_start); - decode_qk_dot_int8(s, q_int8, - key_int8_head.row(0), - q_scale, - key_scales_head.row(0), - n_start, block_n, embed_dim, _scale); - - if (mask_ptr) - decode_mask_vec(s, mask_ptr + 
n_start, block_n); - - float tile_m = decode_max_vec(s, block_n); - - float new_m = std::max(m, tile_m); - if (m != new_m) - { - float scale_factor = expf(m - new_m); - l *= scale_factor; - vec_scale(out, scale_factor, out_embed_dim); - } - - float l_add; - decode_exp_sum_vec(s, block_n, new_m, &l_add); - l += l_add; - - decode_pv_gemv_int8(out, s, - value_int8_head.row(0), - value_scales_head.row(0), - n_start, block_n, out_embed_dim); - - m = new_m; + sdpa_int8_decode_core(s, out, &m, &l, + q_int8, q_scale, + key_int8_head.row(0), + key_scales_head.row(0), + value_int8_head.row(0), + value_scales_head.row(0), + mask_ptr, + n_start, block_n, embed_dim, out_embed_dim, _scale); } float* outptr = top_blob_head.row(0); From da81931f7608f1f07a73143f04a9c722388eb5b6 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Mon, 4 May 2026 23:56:10 +0800 Subject: [PATCH 41/53] refactor(sdpa_x86): extract top-level decode/prefill path functions from forward() MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract 4 top-level static functions to make forward() a thin dispatcher: - sdpa_decode_int8_x86() — INT8 decode (src_seqlen==1) - sdpa_prefill_int8_x86() — INT8 prefill (src_seqlen>1) - sdpa_decode_bf16s_x86() — BF16 decode - sdpa_decode_x86() — FP32 decode Also includes prior refactoring: - Remove redundant *_dispatch wrappers from sdpa_x86_bf16s.h - Unify BF16 decode to reuse generic FP32 helpers (decode_mask_vec, decode_max_vec, etc.) - Normalize BF16 micro-kernel naming with _kernel suffix - Extract INT8 KV quantization into sdpa_quantize_key_value_int8_x86() - Add missing include guard to sdpa_x86_int8.h No functional change; decode performance improved vs baseline. 
--- src/layer/x86/sdpa_x86.cpp | 2914 ++++++++++++++++++++------------ src/layer/x86/sdpa_x86_bf16s.h | 638 +------ src/layer/x86/sdpa_x86_int8.h | 5 + 3 files changed, 1855 insertions(+), 1702 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 30ec27ef272f..c1827835cea3 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3157,348 +3157,6 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, #endif // __SSE2__ -static inline void qk_gemm_dispatch(float* S, const float* Q, const float* K, - int m, int n, int d, float scale) -{ - if (d == 128) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<128>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<128>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<128>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 64) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<64>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<64>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<64>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 512) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<512>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<512>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<512>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 256) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<256>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<256>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<256>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 32) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<32>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<32>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<32>(S, Q, K, m, n, scale); - return; -#endif - } - 
if (d == 80) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<80>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<80>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<80>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 96) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<96>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<96>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<96>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 160) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<160>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<160>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<160>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 1024) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<1024>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<1024>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<1024>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 2048) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<2048>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<2048>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<2048>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 4096) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<4096>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<4096>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<4096>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 768) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<768>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<768>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<768>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 1536) - { -#if __AVX512F__ - 
qk_gemm_specialized_avx512<1536>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<1536>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<1536>(S, Q, K, m, n, scale); - return; -#endif - } - if (d == 3072) - { -#if __AVX512F__ - qk_gemm_specialized_avx512<3072>(S, Q, K, m, n, scale); - return; -#elif __AVX__ - qk_gemm_specialized_avx<3072>(S, Q, K, m, n, scale); - return; -#elif __SSE2__ - qk_gemm_specialized_sse2<3072>(S, Q, K, m, n, scale); - return; -#endif - } - -#if __AVX512F__ - qk_gemm_avx512(S, Q, K, m, n, d, scale); -#elif __AVX__ - qk_gemm_avx(S, Q, K, m, n, d, scale); -#elif __SSE2__ - qk_gemm_sse2(S, Q, K, m, n, d, scale); -#else - qk_gemm_scalar(S, Q, K, m, n, d, scale); -#endif -} - -static inline void pv_gemm_dispatch(float* O, const float* P, const float* V, - int m, int n, int d) -{ - if (d == 128) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n); - return; -#elif __AVX__ - pv_gemm_avx<2, 128>(O, P, V, m, n); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 128>(O, P, V, m, n); - return; -#endif - } - if (d == 64) - { -#if __AVX512F__ - pv_gemm_avx512<4, 64>(O, P, V, m, n); - return; -#elif __AVX__ - pv_gemm_avx<4, 64>(O, P, V, m, n); - return; -#elif __SSE2__ - pv_gemm_sse2<4, 64>(O, P, V, m, n); - return; -#endif - } - if (d == 256) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 256); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 256); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 256); - return; -#endif - } - if (d == 512) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 512); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 512); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 512); - return; -#endif - } - if (d == 1024) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 1024); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 1024); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, 
P, V, m, n, 1024); - return; -#endif - } - if (d == 2048) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 2048); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 2048); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 2048); - return; -#endif - } - if (d == 4096) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 4096); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 4096); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 4096); - return; -#endif - } - if (d == 768) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 768); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 768); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 768); - return; -#endif - } - if (d == 1536) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 1536); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 1536); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 1536); - return; -#endif - } - if (d == 3072) - { -#if __AVX512F__ - pv_gemm_avx512<2, 128>(O, P, V, m, n, 3072); - return; -#elif __AVX__ - pv_gemm_avx<2, 64>(O, P, V, m, n, 3072); - return; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(O, P, V, m, n, 3072); - return; -#endif - } - -#if __AVX512F__ - pv_gemm_avx512<4, 64>(O, P, V, m, n, d); -#elif __AVX__ - pv_gemm_avx<2, 32>(O, P, V, m, n, d); -#elif __SSE2__ - pv_gemm_sse2<2, 16>(O, P, V, m, n, d); -#else - pv_gemm_scalar(O, P, V, m, n, d); -#endif -} - // Timing instrumentation removed static int sdpa_forward_prefill( @@ -3619,18 +3277,211 @@ static int sdpa_forward_prefill( #if NCNN_BF16 if (use_bf16_path) { - qk_gemm_bf16s_dispatch(s_ptr, - q_batch_thread.row(0), - key_head.row(n_start), - block_m * num_heads_per_group, block_n, embed_dim, _scale); +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + qk_gemm_bf16s_avx512bf16(s_ptr, + q_batch_thread.row(0), + 
key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } + else +#endif + { + qk_gemm_bf16s_avx512_kernel(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); + } +#elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); +#else + qk_gemm_bf16s_scalar_kernel(s_ptr, + q_batch_thread.row(0), + key_head.row(n_start), + block_m * num_heads_per_group, block_n, embed_dim, _scale); +#endif } else #endif { - qk_gemm_dispatch(s_ptr, - q_batch_thread.row(0), - key_head.row(n_start), - block_m * num_heads_per_group, block_n, embed_dim, _scale); + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 32: +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<80>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<96>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_specialized_avx512<4096>(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<4096>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<4096>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#else + qk_gemm_scalar(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); +#endif + break; + } } for (int hq = 0; hq < num_heads_per_group; hq++) @@ -3671,23 +3522,166 @@ static int sdpa_forward_prefill( #if NCNN_BF16 if (use_bf16_path) { - pv_gemm_bf16s_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), - block_m * num_heads_per_group, block_n, out_embed_dim); +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } + else +#endif + { + pv_gemm_bf16s_avx512_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); + } +#elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); +#elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); +#else + pv_gemm_bf16s_scalar_kernel(o_accum_thread.row(0), s_ptr, value_head.row(n_start), + block_m * num_heads_per_group, block_n, out_embed_dim); +#endif } else #endif { - pv_gemm_dispatch(o_accum_thread.row(0), s_ptr, value_head.row(n_start), - block_m * num_heads_per_group, block_n, out_embed_dim); - } - } - else - { - // large_dim: copy Q per head once, then loop over N-tiles - for (int hq = 0; hq < 
num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query_ref.channel(q); + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 512); + break; +#endif + case 1024: +#if 
__AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, 
block_n, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#else + pv_gemm_scalar(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, out_embed_dim); +#endif + break; + } + } + } + else + { + // large_dim: copy Q per head once, then loop over N-tiles + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); float* q_dst = q_batch_thread.row(0); for (int i = 0; i < block_m; i++) @@ -3705,18 +3699,211 @@ static int sdpa_forward_prefill( #if NCNN_BF16 if (use_bf16_path) { - qk_gemm_bf16s_dispatch(s_head, - q_dst, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if 
(ncnn::cpu_support_x86_avx512_bf16()) + { + qk_gemm_bf16s_avx512bf16(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else +#endif + { + qk_gemm_bf16s_avx512_kernel(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } +#elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_bf16s_scalar_kernel(s_head, + q_dst, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); +#endif } else #endif { - qk_gemm_dispatch(s_head, - q_dst, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + 
case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 32: +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<80>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + 
qk_gemm_specialized_avx512<1024>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_specialized_avx512<4096>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<4096>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<4096>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + 
qk_gemm_specialized_avx512<3072>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_scalar(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#endif + break; + } } if (attn_mask && mask_data[hq]) @@ -3752,14 +3939,157 @@ static int sdpa_forward_prefill( #if NCNN_BF16 if (use_bf16_path) { - pv_gemm_bf16s_dispatch(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + else +#endif + { + pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } +#elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#else + pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); +#endif } else #endif { - pv_gemm_dispatch(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, 
s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#endif + 
break; + } } } } @@ -3821,53 +4151,1123 @@ static int sdpa_forward_prefill( return 0; } -int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const + +static int sdpa_quantize_key_value_int8_x86(const Mat& key, const Mat& value, + Mat& key_int8, Mat& key_scales, Mat& value_int8, Mat& value_scales, + int num_group, int dst_seqlen, int embed_dim, int out_embed_dim, int v_num_blocks, + bool cache_valid, int past_seqlen, bool kv_cache, + const Option& opt) { - Option opt = _opt; - if (int8_scale_term) + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) { - opt.use_packing_layout = false; // TODO enable packing - } - - const Mat& query = bottom_blobs[0]; - const Mat& cur_key = bottom_blobs[1]; - const Mat& cur_value = bottom_blobs[2]; - const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); - const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); - const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + const Mat key_head = key.channel(g); + Mat key_int8_head = key_int8.channel(g); + Mat key_scales_head = key_scales.channel(g); + int j_start = cache_valid ? past_seqlen : 0; + for (int j = j_start; j < dst_seqlen; j++) + { + dynamic_quantize_rowwise(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); + key_scales_head.row(j)[0] = 1.f / key_scales_head.row(j)[0]; + } - const int embed_dim = query.w; - const int src_seqlen = query.h; - const int num_heads = query.c; - const int cur_seqlen = cur_key.h; - const int num_group = cur_key.c; - const int out_embed_dim = cur_value.w; - const int past_seqlen = kv_cache ? 
past_key.h : 0; - const int dst_seqlen = past_seqlen + cur_seqlen; + if (kv_cache) + { + const Mat value_head = value.channel(g); + Mat value_int8_head = value_int8.channel(g); + Mat value_scales_head = value_scales.channel(g); + for (int j = j_start; j < dst_seqlen; j++) + { + dynamic_quantize_blockwise(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); + for (int vb = 0; vb < v_num_blocks; vb++) + { + value_scales_head.row(j)[vb] = 1.f / value_scales_head.row(j)[vb]; + } + } + } + } + return 0; +} - const size_t elemsize = query.elemsize; +static int sdpa_decode_int8_x86( + const Mat& query, + const Mat& key_int8, + const Mat& key_scales, + const Mat& value_int8, + const Mat& value_scales, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask, + Mat& o_accum, + Mat& s_vec, + Mat& q_int8_tile, + Mat& q_scales_tile) +{ + const int BLOCK_N = 128; + // Decode path with dedicated int8 GEMV kernels + // For GQA/MQA, group-parallel reduces KV cache contention + const bool group_parallel = num_group >= opt.num_threads; - Mat key; - if (past_seqlen > 0) + if (group_parallel) { - key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); - if (key.empty()) - return -100; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_group; q++) + for (int g = 0; g < num_group; g++) { - const Mat past_key_head = past_key.channel(q); - const Mat cur_key_head = cur_key.channel(q); - Mat key_head = key.channel(q); + const Mat key_int8_head = key_int8.channel(g); + const Mat key_scales_head = key_scales.channel(g); + const Mat value_int8_head = value_int8.channel(g); + const Mat value_scales_head = value_scales.channel(g); - memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * elemsize); - memcpy(key_head.row(past_seqlen), cur_key_head, 
embed_dim * cur_seqlen * elemsize); - } - } - else - { - key = cur_key; + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query.channel(q); + Mat top_blob_head = top_blob.channel(q); + + signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); + float* q_scale = q_scales_tile.channel(get_omp_thread_num()); + dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); + q_scale[0] = 1.f / q_scale[0]; + + float* s = s_vec.channel(get_omp_thread_num()); + float* out = o_accum.channel(get_omp_thread_num()); + vec_zero(out, out_embed_dim); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + + sdpa_int8_decode_core(s, out, &m, &l, + q_int8, q_scale, + key_int8_head.row(0), + key_scales_head.row(0), + value_int8_head.row(0), + value_scales_head.row(0), + mask_ptr, + n_start, block_n, embed_dim, out_embed_dim, _scale); + } + + float* outptr = top_blob_head.row(0); + float inv_l = 1.f / l; + memcpy(outptr, out, out_embed_dim * sizeof(float)); + vec_scale(outptr, inv_l, out_embed_dim); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); + + signed 
char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); + float* q_scale = q_scales_tile.channel(get_omp_thread_num()); + dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); + q_scale[0] = 1.f / q_scale[0]; + + float* s = s_vec.channel(get_omp_thread_num()); + float* out = o_accum.channel(get_omp_thread_num()); + vec_zero(out, out_embed_dim); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + + sdpa_int8_decode_core(s, out, &m, &l, + q_int8, q_scale, + key_int8_head.row(0), + key_scales_head.row(0), + value_int8_head.row(0), + value_scales_head.row(0), + mask_ptr, + n_start, block_n, embed_dim, out_embed_dim, _scale); + } + + float* outptr = top_blob_head.row(0); + float inv_l = 1.f / l; + memcpy(outptr, out, out_embed_dim * sizeof(float)); + vec_scale(outptr, inv_l, out_embed_dim); + } + } + + return 0; +} + +static int sdpa_prefill_int8_x86( + const Mat& query, + const Mat& key_int8, + const Mat& key_scales, + const Mat& value_int8, + const Mat& value_scales, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int out_embed_dim, + int src_seqlen, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask, + int kv_cache, + Mat& o_accum, + Mat& s_vec, + Mat& p_vec, + Mat& q_int8_tile, + Mat& q_scales_tile, + const Mat& value) +{ + const int BLOCK_M = 64; + const int BLOCK_N = 128; +#pragma omp parallel for num_threads(opt.num_threads) +for (int q = 0; q < num_heads; q++) +{ + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / 
num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); + + Mat mask_head; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + } + + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + float* p_vec_ptr = p_vec.row(get_omp_thread_num()); + Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); + Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); + + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + { + int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; + + for (int i = 0; i < block_m; i++) + { + dynamic_quantize_rowwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + vec_zero(optr, out_embed_dim); + } + + float m_vec[BLOCK_M]; + float l_vec[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + if (block_m == 1) + { + qk_int8_gemm_row(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0)[0], + key_scales_head.row(n_start), + block_n, embed_dim, _scale); + } + else + { + qk_int8_gemm_tiled(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } + + for (int i = 0; i < block_m; i++) + { + float* s_row = (block_m == 1) ? s_vec_ptr : (s_vec_ptr + i * block_n); + float* p_row = (block_m == 1) ? p_vec_ptr : (p_vec_ptr + i * block_n); + + if (attn_mask) + decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s_row + j)); + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s_row + j)); + } + float m_new = _mm512_comp_reduce_max_ps(vmax); +#elif __AVX__ + __m256 vmax = _mm256_set1_ps(m_vec[i]); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s_row + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); +#elif __SSE2__ + __m128 vmax = _mm_set1_ps(m_vec[i]); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s_row + j)); + float m_new = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); +#else + float m_new = m_vec[i]; + for (int j = 0; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); +#endif + + float scale_factor = expf(m_vec[i] - m_new); + float l_new = l_vec[i] * scale_factor; + + float* optr = o_accum_head.row(i); + vec_scale(optr, scale_factor, out_embed_dim); + +#if __AVX512F__ 
+ __m512 vm_new = _mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s_row + j), vm_new)); + _mm512_storeu_ps(p_row + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, s_row + j), vm_new)); + _mm512_mask_storeu_ps(p_row + j, mask, pvec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, pvec); + } + l_new += _mm512_comp_reduce_add_ps(vsum); +#elif __AVX__ + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s_row + j), vm_new)); + _mm256_storeu_ps(p_row + j, pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + l_new += _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } +#elif __SSE2__ + __m128 vm_new = _mm_set1_ps(m_new); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s_row + j), vm_new)); + _mm_storeu_ps(p_row + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + l_new += _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } +#else + for (int j = 0; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } +#endif + + m_vec[i] = m_new; + l_vec[i] = l_new; + } + + if (kv_cache) + { + pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, + value_int8_head.row(n_start), + value_scales_head.row(n_start), + block_m, block_n, out_embed_dim); + } + else + { + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif 
__AVX__ + pv_gemm_avx<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, 
value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / 
num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#else + pv_gemm_scalar(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); +#endif + break; + } + } + } + + for (int i = 0; i < block_m; i++) + { + float* optr = o_accum_head.row(i); + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(optr + k), vinv_l)); + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, optr + k), vinv_l)); + } +#elif __AVX__ + __m256 vinv_l256 = 
_mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(optr + k), vinv_l256)); +#elif __SSE2__ + __m128 vinv_l128 = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(optr + k), vinv_l128)); +#endif + for (; k < out_embed_dim; k++) + outptr[k] = optr[k] * inv_l; + } + } +} + + return 0; +} + +static int sdpa_decode_bf16s_x86( + const Mat& query_ref, + const Mat& key, + const Mat& value, + Mat& top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask) +{ + const int BLOCK_N = 128; + const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; + if (use_split_kv) + { + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) + { + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; + + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + + const float* qptr = query_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float* p = partials.channel(q).row(chunk); + float* out = p + 2; + float* m_out = p; + float* l_out = p + 1; + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(out, out_embed_dim); + float m = -FLT_MAX; + float l = 0.f; + + for (int n = n_start; n < n_end; n += BLOCK_N) + { + int block_n = std::min(BLOCK_N, n_end - n); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, qptr, Kptr, n, block_n, embed_dim, _scale); +#endif + + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(out, scale_factor, out_embed_dim); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#else + decode_pv_gemv_bf16s_scalar_kernel(out, s, Vptr, n, block_n, out_embed_dim); +#endif + + m = new_m; + } + + *m_out = m; + *l_out = l; + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + float* outptr = 
top_blob_head.row(0); + sdpa_decode_reduce_bf16s(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } + } + else + { + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(outptr, out_embed_dim); + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#endif + + if (mask_ptr) + decode_mask_vec(s, mask_ptr + 
n_start, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(outptr, scale_factor, out_embed_dim); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#else + decode_pv_gemv_bf16s_scalar_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#endif + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale(outptr, inv_l, out_embed_dim); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const unsigned short* Kptr = key_head.row(0); + const unsigned short* Vptr = value_head.row(0); + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + const int BLOCK_N = 128; +#if __AVX512F__ + __attribute__((aligned(64))) float s[BLOCK_N]; +#elif __AVX__ + __attribute__((aligned(32))) float s[BLOCK_N]; +#elif __SSE2__ + __attribute__((aligned(16))) float s[BLOCK_N]; +#else + float s[BLOCK_N]; +#endif + + vec_zero(outptr, out_embed_dim); + float m = -FLT_MAX; + float l = 0.f; + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int block_n = std::min(BLOCK_N, dst_seqlen - n_start); + +#if __AVX512F__ + decode_qk_dot_bf16s_avx512_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __AVX__ + decode_qk_dot_bf16s_avx_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#elif __SSE2__ + decode_qk_dot_bf16s_sse2_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#else + decode_qk_dot_bf16s_scalar_kernel(s, qptr, Kptr, n_start, block_n, embed_dim, _scale); +#endif + + if (mask_ptr) + decode_mask_vec(s, mask_ptr + n_start, block_n); + + float tile_m = decode_max_vec(s, block_n); + float new_m = std::max(m, tile_m); + if (m != new_m) + { + float scale_factor = expf(m - new_m); + l *= scale_factor; + vec_scale(outptr, scale_factor, out_embed_dim); + } + + float l_add; + decode_exp_sum_vec(s, block_n, new_m, &l_add); + l += l_add; + +#if __AVX512F__ + decode_pv_gemv_bf16s_avx512_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __AVX__ + decode_pv_gemv_bf16s_avx_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#elif __SSE2__ + decode_pv_gemv_bf16s_sse2_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#else + decode_pv_gemv_bf16s_scalar_kernel(outptr, s, Vptr, n_start, block_n, out_embed_dim); +#endif + + m = new_m; + } + + float inv_l = 1.f / l; + vec_scale(outptr, inv_l, out_embed_dim); + } + } + } + + return 0; +} + +static int sdpa_decode_x86( + const Mat& query_ref, + const Mat& key, + const Mat& value, + Mat& 
top_blob, + const Mat& attn_mask_ref, + const Option& opt, + int embed_dim, + int num_heads, + int num_group, + int out_embed_dim, + int dst_seqlen, + int num_heads_per_group, + float _scale, + int attn_mask) +{ + const int BLOCK_N = 128; + const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; +if (use_split_kv) +{ + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) + { + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; + + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + + const float* qptr = query_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + float* p = partials.channel(q).row(chunk); + sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, + n_start, n_end, embed_dim, out_embed_dim, _scale); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + float* outptr = top_blob_head.row(0); + sdpa_decode_reduce(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } +} +else +{ + // Decode path: fused GEMV kernel for single-query attention + const bool group_parallel = num_group >= opt.num_threads; + + if (group_parallel) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } +} + + return 0; +} + +int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& _opt) const +{ + Option opt = _opt; + if (int8_scale_term) + { + opt.use_packing_layout = false; // TODO enable packing + } + + const Mat& query = bottom_blobs[0]; + const Mat& cur_key = bottom_blobs[1]; + const Mat& cur_value = bottom_blobs[2]; + const Mat& attn_mask_blob = attn_mask ? bottom_blobs[3] : Mat(); + const Mat& past_key = kv_cache ? bottom_blobs[attn_mask ? 4 : 3] : Mat(); + const Mat& past_value = kv_cache ? bottom_blobs[attn_mask ? 5 : 4] : Mat(); + + const int embed_dim = query.w; + const int src_seqlen = query.h; + const int num_heads = query.c; + const int cur_seqlen = cur_key.h; + const int num_group = cur_key.c; + const int out_embed_dim = cur_value.w; + const int past_seqlen = kv_cache ? 
past_key.h : 0; + const int dst_seqlen = past_seqlen + cur_seqlen; + + const size_t elemsize = query.elemsize; + + Mat key; + if (past_seqlen > 0) + { + key.create(embed_dim, dst_seqlen, num_group, elemsize, opt.blob_allocator); + if (key.empty()) + return -100; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_group; q++) + { + const Mat past_key_head = past_key.channel(q); + const Mat cur_key_head = cur_key.channel(q); + Mat key_head = key.channel(q); + + memcpy(key_head.row(0), past_key_head, embed_dim * past_seqlen * elemsize); + memcpy(key_head.row(past_seqlen), cur_key_head, embed_dim * cur_seqlen * elemsize); + } + } + else + { + key = cur_key; } Mat value; @@ -3979,34 +5379,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } } - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) - { - const Mat key_head = key.channel(g); - Mat key_int8_head = key_int8.channel(g); - Mat key_scales_head = key_scales.channel(g); - int j_start = cache_valid ? 
past_seqlen : 0; - for (int j = j_start; j < dst_seqlen; j++) - { - dynamic_quantize_rowwise(key_head.row(j), key_int8_head.row(j), key_scales_head.row(j), embed_dim); - key_scales_head.row(j)[0] = 1.f / key_scales_head.row(j)[0]; - } - - if (kv_cache) - { - const Mat value_head = value.channel(g); - Mat value_int8_head = value_int8.channel(g); - Mat value_scales_head = value_scales.channel(g); - for (int j = j_start; j < dst_seqlen; j++) - { - dynamic_quantize_blockwise(value_head.row(j), value_int8_head.row(j), value_scales_head.row(j), out_embed_dim); - for (int vb = 0; vb < v_num_blocks; vb++) - { - value_scales_head.row(j)[vb] = 1.f / value_scales_head.row(j)[vb]; - } - } - } - } + sdpa_quantize_key_value_int8_x86(key, value, key_int8, key_scales, value_int8, value_scales, + num_group, dst_seqlen, embed_dim, out_embed_dim, v_num_blocks, + cache_valid, past_seqlen, kv_cache, opt); Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); @@ -4018,136 +5393,13 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return -100; if (src_seqlen == 1) - { - // Decode path with dedicated int8 GEMV kernels - // For GQA/MQA, group-parallel reduces KV cache contention - const bool group_parallel = num_group >= opt.num_threads; - - if (group_parallel) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) - { - const Mat key_int8_head = key_int8.channel(g); - const Mat key_scales_head = key_scales.channel(g); - const Mat value_int8_head = value_int8.channel(g); - const Mat value_scales_head = value_scales.channel(g); - - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query.channel(q); - Mat top_blob_head = top_blob.channel(q); - - signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); - float* q_scale = 
q_scales_tile.channel(get_omp_thread_num()); - dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); - q_scale[0] = 1.f / q_scale[0]; - - float* s = s_vec.channel(get_omp_thread_num()); - float* out = o_accum.channel(get_omp_thread_num()); - vec_zero(out, out_embed_dim); - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - float m = -FLT_MAX; - float l = 0.f; - - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) - { - int block_n = std::min(BLOCK_N, dst_seqlen - n_start); - - sdpa_int8_decode_core(s, out, &m, &l, - q_int8, q_scale, - key_int8_head.row(0), - key_scales_head.row(0), - value_int8_head.row(0), - value_scales_head.row(0), - mask_ptr, - n_start, block_n, embed_dim, out_embed_dim, _scale); - } - - float* outptr = top_blob_head.row(0); - float inv_l = 1.f / l; - memcpy(outptr, out, out_embed_dim * sizeof(float)); - vec_scale(outptr, inv_l, out_embed_dim); - } - } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - const Mat query_head = query.channel(q); - const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); - const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); - const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); - const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); - Mat top_blob_head = top_blob.channel(q); - - signed char* q_int8 = q_int8_tile.channel(get_omp_thread_num()); - float* q_scale = q_scales_tile.channel(get_omp_thread_num()); - dynamic_quantize_rowwise(query_head.row(0), q_int8, q_scale, embed_dim); - q_scale[0] = 1.f / q_scale[0]; - - float* s = s_vec.channel(get_omp_thread_num()); - float* out = o_accum.channel(get_omp_thread_num()); - 
vec_zero(out, out_embed_dim); - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - float m = -FLT_MAX; - float l = 0.f; - - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) - { - int block_n = std::min(BLOCK_N, dst_seqlen - n_start); - - sdpa_int8_decode_core(s, out, &m, &l, - q_int8, q_scale, - key_int8_head.row(0), - key_scales_head.row(0), - value_int8_head.row(0), - value_scales_head.row(0), - mask_ptr, - n_start, block_n, embed_dim, out_embed_dim, _scale); - } - - float* outptr = top_blob_head.row(0); - float inv_l = 1.f / l; - memcpy(outptr, out, out_embed_dim * sizeof(float)); - vec_scale(outptr, inv_l, out_embed_dim); - } - } + { + int ret = sdpa_decode_int8_x86(query, key_int8, key_scales, value_int8, value_scales, + top_blob, attn_mask_ref, opt, embed_dim, num_heads, num_group, out_embed_dim, + dst_seqlen, num_heads_per_group, _scale, attn_mask, + o_accum, s_vec, q_int8_tile, q_scales_tile); + if (ret != 0) + return ret; if (kv_cache && dst_seqlen > 0) { @@ -4163,239 +5415,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } - // else: fall through to fp32 decode path below - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) + else { - const Mat query_head = query.channel(q); - const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); - const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); - const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); - const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); - Mat top_blob_head = top_blob.channel(q); - - Mat mask_head; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - if (maskm.dims == 3) - { - mask_head = maskm.c > 
1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - } - - Mat o_accum_head = o_accum.channel(get_omp_thread_num()); - float* s_vec_ptr = s_vec.row(get_omp_thread_num()); - float* p_vec_ptr = p_vec.row(get_omp_thread_num()); - Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); - Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); - - for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) - { - int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; - int block_m = m_end - m_start; - - for (int i = 0; i < block_m; i++) - { - dynamic_quantize_rowwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); - q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; - } - - for (int i = 0; i < block_m; i++) - { - float* optr = o_accum_head.row(i); - vec_zero(optr, out_embed_dim); - } - - float m_vec[BLOCK_M]; - float l_vec[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_vec[i] = -FLT_MAX; - l_vec[i] = 0.f; - } - - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) - { - int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; - int block_n = n_end - n_start; - - if (block_m == 1) - { - qk_int8_gemm_row(s_vec_ptr, - q_int8_tile_head.row(0), - key_int8_head.row(n_start), - q_scales_tile_head.row(0)[0], - key_scales_head.row(n_start), - block_n, embed_dim, _scale); - } - else - { - qk_int8_gemm_tiled(s_vec_ptr, - q_int8_tile_head.row(0), - key_int8_head.row(n_start), - q_scales_tile_head.row(0), - key_scales_head.row(n_start), - block_m, block_n, embed_dim, _scale); - } - - for (int i = 0; i < block_m; i++) - { - float* s_row = (block_m == 1) ? s_vec_ptr : (s_vec_ptr + i * block_n); - float* p_row = (block_m == 1) ? 
p_vec_ptr : (p_vec_ptr + i * block_n); - - if (attn_mask) - decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); - -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(m_vec[i]); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s_row + j)); - if (j < block_n) - { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s_row + j)); - } - float m_new = _mm512_comp_reduce_max_ps(vmax); -#elif __AVX__ - __m256 vmax = _mm256_set1_ps(m_vec[i]); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s_row + j)); - float m_new = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); -#elif __SSE2__ - __m128 vmax = _mm_set1_ps(m_vec[i]); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s_row + j)); - float m_new = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); -#else - float m_new = m_vec[i]; - for (int j = 0; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); -#endif - - float scale_factor = expf(m_vec[i] - m_new); - float l_new = l_vec[i] * scale_factor; - - float* optr = o_accum_head.row(i); - vec_scale(optr, scale_factor, out_embed_dim); - -#if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(m_new); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s_row + j), vm_new)); - _mm512_storeu_ps(p_row + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, s_row + j), vm_new)); - _mm512_mask_storeu_ps(p_row + j, mask, pvec); - vsum = _mm512_mask_add_ps(vsum, mask, vsum, pvec); - } - l_new += _mm512_comp_reduce_add_ps(vsum); -#elif __AVX__ - __m256 
vm_new = _mm256_set1_ps(m_new); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s_row + j), vm_new)); - _mm256_storeu_ps(p_row + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - l_new += _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } -#elif __SSE2__ - __m128 vm_new = _mm_set1_ps(m_new); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s_row + j), vm_new)); - _mm_storeu_ps(p_row + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - l_new += _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } -#else - for (int j = 0; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } -#endif - - m_vec[i] = m_new; - l_vec[i] = l_new; - } - - if (kv_cache) - { - pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, - value_int8_head.row(n_start), - value_scales_head.row(n_start), - block_m, block_n, out_embed_dim); - } - else - { - pv_gemm_dispatch(o_accum_head.row(0), p_vec_ptr, - value.channel(q / num_heads_per_group).row(n_start), - block_m, block_n, out_embed_dim); - } - } - - for (int i = 0; i < block_m; i++) - { - float* optr = o_accum_head.row(i); - float* outptr = top_blob_head.row(m_start + i); - float inv_l = 1.f / l_vec[i]; - int k = 0; -#if __AVX512F__ - __m512 vinv_l = _mm512_set1_ps(inv_l); - for (; k + 15 < out_embed_dim; k += 16) - _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(optr + k), vinv_l)); - if (k < out_embed_dim) - { - __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); - _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, optr + k), vinv_l)); - } -#elif __AVX__ - __m256 vinv_l256 = _mm256_set1_ps(inv_l); - for (; k + 7 < out_embed_dim; k += 8) - 
_mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(optr + k), vinv_l256)); -#elif __SSE2__ - __m128 vinv_l128 = _mm_set1_ps(inv_l); - for (; k + 3 < out_embed_dim; k += 4) - _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(optr + k), vinv_l128)); -#endif - for (; k < out_embed_dim; k++) - outptr[k] = optr[k] * inv_l; - } - } + int ret = sdpa_prefill_int8_x86(query, key_int8, key_scales, value_int8, value_scales, + top_blob, attn_mask_ref, opt, embed_dim, num_heads, out_embed_dim, src_seqlen, + dst_seqlen, num_heads_per_group, _scale, attn_mask, kv_cache, + o_accum, s_vec, p_vec, q_int8_tile, q_scales_tile, value); + if (ret != 0) + return ret; } if (kv_cache) @@ -4423,145 +5450,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to // FP32 optimized path using tiled GEMM + online softmax if (src_seqlen == 1) { - const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; - #if NCNN_BF16 if (use_bf16_path) { - if (use_split_kv) - { - const int num_kv_chunks = opt.num_threads; - Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); - if (partials.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int task = 0; task < num_heads * num_kv_chunks; task++) - { - int q = task / num_kv_chunks; - int chunk = task % num_kv_chunks; - - int n_start = chunk * dst_seqlen / num_kv_chunks; - int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; - - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); - - const float* qptr = query_head.row(0); - const unsigned short* Kptr = key_head.row(0); - const unsigned short* Vptr = value_head.row(0); - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - float* p = partials.channel(q).row(chunk); - sdpa_decode_chunk_bf16s_dispatch(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, - n_start, n_end, embed_dim, out_embed_dim, _scale); - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - Mat top_blob_head = top_blob.channel(q); - float* outptr = top_blob_head.row(0); - sdpa_decode_reduce_bf16s_dispatch(outptr, out_embed_dim, - partials.channel(q), num_kv_chunks, 2 + out_embed_dim); - } - } - else - { - const bool group_parallel = num_group >= opt.num_threads; - - if (group_parallel) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) - { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const unsigned short* Kptr = key_head.row(0); - const unsigned short* Vptr = value_head.row(0); - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode_bf16s_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const unsigned short* Kptr = key_head.row(0); - const unsigned short* Vptr = value_head.row(0); - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode_bf16s_dispatch(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } - } + int ret = sdpa_decode_bf16s_x86(query_ref, key, value, top_blob, attn_mask_ref, opt, + embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, attn_mask); + if (ret != 0) + return ret; if (kv_cache) { @@ -4573,141 +5469,11 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } #endif // NCNN_BF16 - if (use_split_kv) - { - const int num_kv_chunks = opt.num_threads; - Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); - if (partials.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int task = 0; task < num_heads * num_kv_chunks; task++) - { - int q = task / num_kv_chunks; - int chunk = task % num_kv_chunks; - - int n_start = chunk * dst_seqlen / num_kv_chunks; - int n_end = (chunk + 1 == num_kv_chunks) ? 
dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; - - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); - - const float* qptr = query_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - float* p = partials.channel(q).row(chunk); - sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, - n_start, n_end, embed_dim, out_embed_dim, _scale); - } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - Mat top_blob_head = top_blob.channel(q); - float* outptr = top_blob_head.row(0); - sdpa_decode_reduce(outptr, out_embed_dim, - partials.channel(q), num_kv_chunks, 2 + out_embed_dim); - } - } - else - { - // Decode path: fused GEMV kernel for single-query attention - const bool group_parallel = num_group >= opt.num_threads; - - if (group_parallel) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) - { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } - } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } - } + int ret = sdpa_decode_x86(query_ref, key, value, top_blob, attn_mask_ref, opt, + embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, attn_mask); + if (ret != 0) + return ret; if (kv_cache) { diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h index b26a70706ae3..60765374b34e 100644 --- a/src/layer/x86/sdpa_x86_bf16s.h +++ b/src/layer/x86/sdpa_x86_bf16s.h @@ -409,522 +409,9 @@ static inline void decode_pv_gemv_bf16s_scalar_kernel(float* out, const float* s // --------------------------------------------------------------------------- -// sdpa_decode_bf16s : full decode with bf16 K/V +// sdpa_decode_reduce_bf16s : reduce partial results from split-kv chunks // --------------------------------------------------------------------------- -static inline void sdpa_decode_bf16s(float* out, const float* q, - const unsigned short* K, const unsigned 
short* V, const float* mask, - int n, int d, int out_d, float scale) -{ - const int BLOCK_N = 128; -#if __AVX512F__ - __attribute__((aligned(64))) float s[BLOCK_N]; -#elif __AVX__ - __attribute__((aligned(32))) float s[BLOCK_N]; -#elif __SSE2__ - __attribute__((aligned(16))) float s[BLOCK_N]; -#else - float s[BLOCK_N]; -#endif - - // vec_zero - { -#if __AVX512F__ - __m512 zero512 = _mm512_setzero_ps(); - int i = 0; - for (; i + 15 < out_d; i += 16) - _mm512_storeu_ps(out + i, zero512); - if (i < out_d) - { - __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); - _mm512_mask_storeu_ps(out + i, mask, zero512); - } -#else - int i = 0; -#if __AVX__ - __m256 zero256 = _mm256_setzero_ps(); - for (; i + 7 < out_d; i += 8) - _mm256_storeu_ps(out + i, zero256); -#endif -#if __SSE2__ - __m128 zero128 = _mm_setzero_ps(); - for (; i + 3 < out_d; i += 4) - _mm_storeu_ps(out + i, zero128); -#endif - for (; i < out_d; i++) - out[i] = 0.f; -#endif - } - - float m = -FLT_MAX; - float l = 0.f; - - for (int n_start = 0; n_start < n; n_start += BLOCK_N) - { - int block_n = std::min(BLOCK_N, n - n_start); - -#if __AVX512F__ - decode_qk_dot_bf16s_avx512_kernel(s, q, K, n_start, block_n, d, scale); -#elif __AVX__ - decode_qk_dot_bf16s_avx_kernel(s, q, K, n_start, block_n, d, scale); -#elif __SSE2__ - decode_qk_dot_bf16s_sse2_kernel(s, q, K, n_start, block_n, d, scale); -#else - decode_qk_dot_bf16s_scalar_kernel(s, q, K, n_start, block_n, d, scale); -#endif - - if (mask) - { -#if __AVX512F__ - int j = 0; - for (; j + 15 < block_n; j += 16) - _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n_start + j))); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - _mm512_mask_storeu_ps(s + j, mask_n, - _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n_start + j))); - } -#elif __AVX__ - int j = 0; - for (; j + 7 < block_n; j += 8) - _mm256_storeu_ps(s + j, 
_mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; -#elif __SSE2__ - int j = 0; - for (; j + 3 < block_n; j += 4) - _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n_start + j))); - for (; j < block_n; j++) - s[j] += mask[n_start + j]; -#else - for (int j = 0; j < block_n; j++) - s[j] += mask[n_start + j]; -#endif - } - - // tile max -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); - } - float tile_m = _mm512_comp_reduce_max_ps(vmax); -#elif __AVX__ - __m256 vmax = _mm256_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); - float tile_m = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#elif __SSE2__ - __m128 vmax = _mm_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); - float tile_m = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#else - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#endif - - float new_m = std::max(m, tile_m); - if (m != new_m) - { - float scale_factor = expf(m - new_m); - l *= scale_factor; - // vec_scale(out, scale_factor, out_d); - { -#if __AVX512F__ - __m512 vscale512 = _mm512_set1_ps(scale_factor); - int i = 0; - for (; i + 15 < out_d; i += 16) - _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); - if (i < out_d) - { - __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); - _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + 
i), vscale512)); - } -#else - int i = 0; -#if __AVX__ - __m256 vscale256 = _mm256_set1_ps(scale_factor); - for (; i + 7 < out_d; i += 8) - _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); -#endif -#if __SSE2__ - __m128 vscale128 = _mm_set1_ps(scale_factor); - for (; i + 3 < out_d; i += 4) - _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); -#endif - for (; i < out_d; i++) - out[i] *= scale_factor; -#endif - } - } - - // exp and sum -#if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(new_m); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); - _mm512_storeu_ps(s + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); - _mm512_mask_storeu_ps(s + j, mask_n, pvec); - vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); - } - l += _mm512_comp_reduce_add_ps(vsum); -#elif __AVX__ - __m256 vm_new = _mm256_set1_ps(new_m); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); - _mm256_storeu_ps(s + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - float l_add = _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#elif __SSE2__ - __m128 vm_new = _mm_set1_ps(new_m); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); - _mm_storeu_ps(s + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - float l_add = _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#else - float l_add = 0.f; - for (int j = 0; j < block_n; j++) - { - s[j] = 
expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#endif - -#if __AVX512F__ - decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n_start, block_n, out_d); -#elif __AVX__ - decode_pv_gemv_bf16s_avx_kernel(out, s, V, n_start, block_n, out_d); -#elif __SSE2__ - decode_pv_gemv_bf16s_sse2_kernel(out, s, V, n_start, block_n, out_d); -#else - decode_pv_gemv_bf16s_scalar_kernel(out, s, V, n_start, block_n, out_d); -#endif - - m = new_m; - } - - float inv_l = 1.f / l; - // vec_scale(out, inv_l, out_d); - { -#if __AVX512F__ - __m512 vscale512 = _mm512_set1_ps(inv_l); - int i = 0; - for (; i + 15 < out_d; i += 16) - _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); - if (i < out_d) - { - __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); - _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); - } -#else - int i = 0; -#if __AVX__ - __m256 vscale256 = _mm256_set1_ps(inv_l); - for (; i + 7 < out_d; i += 8) - _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); -#endif -#if __SSE2__ - __m128 vscale128 = _mm_set1_ps(inv_l); - for (; i + 3 < out_d; i += 4) - _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); -#endif - for (; i < out_d; i++) - out[i] *= inv_l; -#endif - } -} - - -// --------------------------------------------------------------------------- -// sdpa_decode_chunk_bf16s / sdpa_decode_reduce_bf16s -// --------------------------------------------------------------------------- - -static inline void sdpa_decode_chunk_bf16s( - float* out, float* m_out, float* l_out, - const float* q, const unsigned short* K, const unsigned short* V, const float* mask, - int n_start, int n_end, int d, int out_d, float scale) -{ - const int BLOCK_N = 128; -#if __AVX512F__ - __attribute__((aligned(64))) float s[BLOCK_N]; -#elif __AVX__ - __attribute__((aligned(32))) float s[BLOCK_N]; -#elif __SSE2__ - __attribute__((aligned(16))) float s[BLOCK_N]; -#else - float 
s[BLOCK_N]; -#endif - - // vec_zero - { -#if __AVX512F__ - __m512 zero512 = _mm512_setzero_ps(); - int i = 0; - for (; i + 15 < out_d; i += 16) - _mm512_storeu_ps(out + i, zero512); - if (i < out_d) - { - __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); - _mm512_mask_storeu_ps(out + i, mask, zero512); - } -#else - int i = 0; -#if __AVX__ - __m256 zero256 = _mm256_setzero_ps(); - for (; i + 7 < out_d; i += 8) - _mm256_storeu_ps(out + i, zero256); -#endif -#if __SSE2__ - __m128 zero128 = _mm_setzero_ps(); - for (; i + 3 < out_d; i += 4) - _mm_storeu_ps(out + i, zero128); -#endif - for (; i < out_d; i++) - out[i] = 0.f; -#endif - } - - float m = -FLT_MAX; - float l = 0.f; - - for (int n = n_start; n < n_end; n += BLOCK_N) - { - int block_n = std::min(BLOCK_N, n_end - n); - -#if __AVX512F__ - decode_qk_dot_bf16s_avx512_kernel(s, q, K, n, block_n, d, scale); -#elif __AVX__ - decode_qk_dot_bf16s_avx_kernel(s, q, K, n, block_n, d, scale); -#elif __SSE2__ - decode_qk_dot_bf16s_sse2_kernel(s, q, K, n, block_n, d, scale); -#else - decode_qk_dot_bf16s_scalar_kernel(s, q, K, n, block_n, d, scale); -#endif - - if (mask) - { -#if __AVX512F__ - int j = 0; - for (; j + 15 < block_n; j += 16) - _mm512_storeu_ps(s + j, _mm512_add_ps(_mm512_loadu_ps(s + j), _mm512_loadu_ps(mask + n + j))); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - _mm512_mask_storeu_ps(s + j, mask_n, - _mm512_add_ps(_mm512_maskz_loadu_ps(mask_n, s + j), _mm512_maskz_loadu_ps(mask_n, mask + n + j))); - } -#elif __AVX__ - int j = 0; - for (; j + 7 < block_n; j += 8) - _mm256_storeu_ps(s + j, _mm256_add_ps(_mm256_loadu_ps(s + j), _mm256_loadu_ps(mask + n + j))); - for (; j < block_n; j++) - s[j] += mask[n + j]; -#elif __SSE2__ - int j = 0; - for (; j + 3 < block_n; j += 4) - _mm_storeu_ps(s + j, _mm_add_ps(_mm_loadu_ps(s + j), _mm_loadu_ps(mask + n + j))); - for (; j < block_n; j++) - s[j] += mask[n + j]; -#else - for (int j = 0; j < block_n; j++) - s[j] += mask[n + j]; 
-#endif - } - - // tile max -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s + j)); - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask_n, s + j)); - } - float tile_m = _mm512_comp_reduce_max_ps(vmax); -#elif __AVX__ - __m256 vmax = _mm256_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s + j)); - float tile_m = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#elif __SSE2__ - __m128 vmax = _mm_set1_ps(-FLT_MAX); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s + j)); - float tile_m = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#else - float tile_m = -FLT_MAX; - for (int j = 0; j < block_n; j++) - tile_m = std::max(tile_m, s[j]); -#endif - - float new_m = std::max(m, tile_m); - if (m != new_m) - { - float scale_factor = expf(m - new_m); - l *= scale_factor; - // vec_scale - { -#if __AVX512F__ - __m512 vscale512 = _mm512_set1_ps(scale_factor); - int i = 0; - for (; i + 15 < out_d; i += 16) - _mm512_storeu_ps(out + i, _mm512_mul_ps(_mm512_loadu_ps(out + i), vscale512)); - if (i < out_d) - { - __mmask16 mask = (__mmask16)((1u << (out_d - i)) - 1); - _mm512_mask_storeu_ps(out + i, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, out + i), vscale512)); - } -#else - int i = 0; -#if __AVX__ - __m256 vscale256 = _mm256_set1_ps(scale_factor); - for (; i + 7 < out_d; i += 8) - _mm256_storeu_ps(out + i, _mm256_mul_ps(_mm256_loadu_ps(out + i), vscale256)); -#endif -#if __SSE2__ - __m128 vscale128 = _mm_set1_ps(scale_factor); - for (; i + 3 < out_d; i += 4) - _mm_storeu_ps(out + i, _mm_mul_ps(_mm_loadu_ps(out + i), vscale128)); -#endif - for (; i < out_d; i++) - out[i] *= 
scale_factor; -#endif - } - } - - // exp and sum -#if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(new_m); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s + j), vm_new)); - _mm512_storeu_ps(s + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask_n = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask_n, s + j), vm_new)); - _mm512_mask_storeu_ps(s + j, mask_n, pvec); - vsum = _mm512_mask_add_ps(vsum, mask_n, vsum, pvec); - } - l += _mm512_comp_reduce_add_ps(vsum); -#elif __AVX__ - __m256 vm_new = _mm256_set1_ps(new_m); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s + j), vm_new)); - _mm256_storeu_ps(s + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - float l_add = _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#elif __SSE2__ - __m128 vm_new = _mm_set1_ps(new_m); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s + j), vm_new)); - _mm_storeu_ps(s + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - float l_add = _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#else - float l_add = 0.f; - for (int j = 0; j < block_n; j++) - { - s[j] = expf(s[j] - new_m); - l_add += s[j]; - } - l += l_add; -#endif - -#if __AVX512F__ - decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n, block_n, out_d); -#elif __AVX__ - decode_pv_gemv_bf16s_avx_kernel(out, s, V, n, block_n, out_d); -#elif __SSE2__ - decode_pv_gemv_bf16s_sse2_kernel(out, s, V, n, block_n, out_d); -#else - decode_pv_gemv_bf16s_scalar_kernel(out, s, V, n, block_n, out_d); -#endif - - m = new_m; - } - - *m_out = m; - 
*l_out = l; -} - static inline void sdpa_decode_reduce_bf16s( float* out, int out_d, const float* partials, int num_chunks, int partial_stride) @@ -1022,7 +509,7 @@ static inline void sdpa_decode_reduce_bf16s( // --------------------------------------------------------------------------- #if __AVX512F__ -static void qk_gemm_bf16s_avx512(float* S, const float* Q, const unsigned short* K, +static void qk_gemm_bf16s_avx512_kernel(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) { int i = 0; @@ -1265,7 +752,7 @@ static void qk_gemm_bf16s_avx512(float* S, const float* Q, const unsigned short* #if __AVX__ -static void qk_gemm_bf16s_avx(float* S, const float* Q, const unsigned short* K, +static void qk_gemm_bf16s_avx_kernel(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) { int i = 0; @@ -1392,7 +879,7 @@ static void qk_gemm_bf16s_avx(float* S, const float* Q, const unsigned short* K, #endif // __AVX__ #if __SSE2__ -static void qk_gemm_bf16s_sse2(float* S, const float* Q, const unsigned short* K, +static void qk_gemm_bf16s_sse2_kernel(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) { for (int i = 0; i < m; i++) @@ -1442,7 +929,7 @@ static void qk_gemm_bf16s_sse2(float* S, const float* Q, const unsigned short* K } #endif // __SSE2__ -static void qk_gemm_bf16s_scalar(float* S, const float* Q, const unsigned short* K, +static void qk_gemm_bf16s_scalar_kernel(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) { for (int i = 0; i < m; i++) @@ -1465,7 +952,7 @@ static void qk_gemm_bf16s_scalar(float* S, const float* Q, const unsigned short* // --------------------------------------------------------------------------- #if __AVX512F__ -static void pv_gemm_bf16s_avx512(float* O, const float* P, const unsigned short* V, int m, int n, int d) +static void pv_gemm_bf16s_avx512_kernel(float* O, const float* P, const unsigned short* V, int m, 
int n, int d) { int dd = 0; for (; dd + 127 < d; dd += 128) @@ -1622,7 +1109,7 @@ static void pv_gemm_bf16s_avx512(float* O, const float* P, const unsigned short* #endif // __AVX512F__ #if __AVX__ -static void pv_gemm_bf16s_avx(float* O, const float* P, const unsigned short* V, int m, int n, int d) +static void pv_gemm_bf16s_avx_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) { int dd = 0; for (; dd + 31 < d; dd += 32) @@ -1751,7 +1238,7 @@ static void pv_gemm_bf16s_avx(float* O, const float* P, const unsigned short* V, #endif // __AVX__ #if __SSE2__ -static void pv_gemm_bf16s_sse2(float* O, const float* P, const unsigned short* V, int m, int n, int d) +static void pv_gemm_bf16s_sse2_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) { int dd = 0; for (; dd + 15 < d; dd += 16) @@ -1809,7 +1296,7 @@ static void pv_gemm_bf16s_sse2(float* O, const float* P, const unsigned short* V } #endif // __SSE2__ -static void pv_gemm_bf16s_scalar(float* O, const float* P, const unsigned short* V, int m, int n, int d) +static void pv_gemm_bf16s_scalar_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) { for (int i = 0; i < m; i++) { @@ -2143,7 +1630,7 @@ static void qk_gemm_bf16s_avx512bf16_kernel(float* S, const float* Q, const unsi if (i < m) { - qk_gemm_bf16s_avx512(S + i * n, Q + i * d, K, m - i, n, d, scale); + qk_gemm_bf16s_avx512_kernel(S + i * n, Q + i * d, K, m - i, n, d, scale); } _mm_free(q_bf16); @@ -2757,109 +2244,4 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un -// --------------------------------------------------------------------------- -// Dispatch wrappers -// --------------------------------------------------------------------------- - -static inline void decode_qk_dot_bf16s(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale) -{ -#if __AVX512F__ -#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && 
!__AVX512BF16__ - if (ncnn::cpu_support_x86_avx512_bf16()) - { - decode_qk_dot_bf16s_avx512bf16(s, q, K, n_start, block_n, d, scale); - return; - } -#endif - decode_qk_dot_bf16s_avx512_kernel(s, q, K, n_start, block_n, d, scale); -#elif __AVX__ - decode_qk_dot_bf16s_avx_kernel(s, q, K, n_start, block_n, d, scale); -#elif __SSE2__ - decode_qk_dot_bf16s_sse2_kernel(s, q, K, n_start, block_n, d, scale); -#else - decode_qk_dot_bf16s_scalar_kernel(s, q, K, n_start, block_n, d, scale); -#endif -} - -static inline void decode_pv_gemv_bf16s(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d) -{ -#if __AVX512F__ -#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ - if (ncnn::cpu_support_x86_avx512_bf16()) - { - decode_pv_gemv_bf16s_avx512bf16(out, s, V, n_start, block_n, out_d); - return; - } -#endif - decode_pv_gemv_bf16s_avx512_kernel(out, s, V, n_start, block_n, out_d); -#elif __AVX__ - decode_pv_gemv_bf16s_avx_kernel(out, s, V, n_start, block_n, out_d); -#elif __SSE2__ - decode_pv_gemv_bf16s_sse2_kernel(out, s, V, n_start, block_n, out_d); -#else - decode_pv_gemv_bf16s_scalar_kernel(out, s, V, n_start, block_n, out_d); -#endif -} - -static inline void qk_gemm_bf16s_dispatch(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) -{ -#if __AVX512F__ -#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ - if (ncnn::cpu_support_x86_avx512_bf16()) - { - qk_gemm_bf16s_avx512bf16(S, Q, K, m, n, d, scale); - return; - } -#endif - qk_gemm_bf16s_avx512(S, Q, K, m, n, d, scale); -#elif __AVX__ - qk_gemm_bf16s_avx(S, Q, K, m, n, d, scale); -#elif __SSE2__ - qk_gemm_bf16s_sse2(S, Q, K, m, n, d, scale); -#else - qk_gemm_bf16s_scalar(S, Q, K, m, n, d, scale); -#endif -} - -static inline void pv_gemm_bf16s_dispatch(float* O, const float* P, const unsigned short* V, int m, int n, int d) -{ -#if __AVX512F__ -#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ - if 
(ncnn::cpu_support_x86_avx512_bf16()) - { - pv_gemm_bf16s_avx512bf16(O, P, V, m, n, d); - return; - } -#endif - pv_gemm_bf16s_avx512(O, P, V, m, n, d); -#elif __AVX__ - pv_gemm_bf16s_avx(O, P, V, m, n, d); -#elif __SSE2__ - pv_gemm_bf16s_sse2(O, P, V, m, n, d); -#else - pv_gemm_bf16s_scalar(O, P, V, m, n, d); -#endif -} - -static inline void sdpa_decode_bf16s_dispatch(float* out, const float* q, - const unsigned short* K, const unsigned short* V, const float* mask, - int n, int d, int out_d, float scale) -{ - sdpa_decode_bf16s(out, q, K, V, mask, n, d, out_d, scale); -} - -static inline void sdpa_decode_chunk_bf16s_dispatch( - float* out, float* m_out, float* l_out, - const float* q, const unsigned short* K, const unsigned short* V, const float* mask, - int n_start, int n_end, int d, int out_d, float scale) -{ - sdpa_decode_chunk_bf16s(out, m_out, l_out, q, K, V, mask, n_start, n_end, d, out_d, scale); -} - -static inline void sdpa_decode_reduce_bf16s_dispatch( - float* out, int out_d, - const float* partials, int num_chunks, int partial_stride) -{ - sdpa_decode_reduce_bf16s(out, out_d, partials, num_chunks, partial_stride); -} #endif // SDPA_X86_BF16S_H diff --git a/src/layer/x86/sdpa_x86_int8.h b/src/layer/x86/sdpa_x86_int8.h index 4b848c12d823..e30496424098 100644 --- a/src/layer/x86/sdpa_x86_int8.h +++ b/src/layer/x86/sdpa_x86_int8.h @@ -1,3 +1,6 @@ +#ifndef SDPA_X86_INT8_H +#define SDPA_X86_INT8_H + static inline signed char float2int8(float v) { int int32 = static_cast(roundf(v)); @@ -3324,3 +3327,5 @@ static void pv_float_int8_gemm_tile(float* O, const float* P, #if __SSE2__ && !__SSE4_1__ #undef _mm_cvtepi8_epi32 #endif + +#endif // SDPA_X86_INT8_H From d14dea851d8d8f5b3d79040a22ae2b9390b9baaa Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Tue, 5 May 2026 01:42:48 +0800 Subject: [PATCH 42/53] perf(sdpa_x86): optimize MQA prefill for small seqlen - Add n>=256 guard to QK/PV GEMM prefetch to avoid small-n overhead - Change large_dim threshold 
to embed_dim>512 && src_seqlen>16 so that MQA configs with seqlen=16 use the !large_dim batched path - Eliminate Q memcpy in large_dim path by passing query_head.row() directly - Always allocate q_batch with full BLOCK_M*num_heads_per_group size --- src/layer/x86/sdpa_x86.cpp | 122 ++++++++++++++++++------------------- 1 file changed, 59 insertions(+), 63 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index c1827835cea3..76b56d77bc10 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1282,7 +1282,7 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co int j = 0; for (; j + 4 <= n; j += 4) { - if (D >= 512 && j + 8 <= n) + if (D >= 512 && n >= 256 && j + 8 <= n) { _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); @@ -1390,7 +1390,7 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co int j = 0; for (; j + 4 <= n; j += 4) { - if (D >= 512 && j + 8 <= n) + if (D >= 512 && n >= 256 && j + 8 <= n) { _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); @@ -1584,7 +1584,7 @@ template<> void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx512<4096, 8, 2>(S, Q, K, m, n, scale); } template @@ -1612,7 +1612,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int for (int j = 0; j < n; j++) { - if (d >= 512 && j + 4 < n) + if (d >= 512 && n >= 256 && j + 4 < n) _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); __m512 vvec[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) @@ -1642,7 +1642,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int for (int j = 0; j < n; j++) { - if (d >= 512 && j 
+ 4 < n) + if (d >= 512 && n >= 256 && j + 4 < n) _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); __m512 pvec = _mm512_set1_ps(pptr[j]); for (int vi = 0; vi < VEC_PER_UNROLL; vi++) @@ -3184,8 +3184,8 @@ static int sdpa_forward_prefill( const int BLOCK_N = 128; Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - const bool large_dim = embed_dim > 512; - Mat q_batch(embed_dim, large_dim ? BLOCK_M : BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + const bool large_dim = embed_dim > 512 && src_seqlen > 16; + Mat q_batch(embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); if (s_vec.empty() || o_accum.empty() || q_batch.empty()) return -100; @@ -3683,11 +3683,7 @@ static int sdpa_forward_prefill( int q = g * num_heads_per_group + hq; const Mat query_head = query_ref.channel(q); - float* q_dst = q_batch_thread.row(0); - for (int i = 0; i < block_m; i++) - { - memcpy(q_dst + i * embed_dim, query_head.row(m_start + i), embed_dim * sizeof(float)); - } + const float* q_ptr = query_head.row(m_start); for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) { @@ -3704,7 +3700,7 @@ static int sdpa_forward_prefill( if (ncnn::cpu_support_x86_avx512_bf16()) { qk_gemm_bf16s_avx512bf16(s_head, - q_dst, + q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); } @@ -3712,23 +3708,23 @@ static int sdpa_forward_prefill( #endif { qk_gemm_bf16s_avx512_kernel(s_head, - q_dst, + q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); } #elif __AVX__ qk_gemm_bf16s_avx_kernel(s_head, - q_dst, + q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #elif __SSE2__ qk_gemm_bf16s_sse2_kernel(s_head, - q_dst, + q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #else 
qk_gemm_bf16s_scalar_kernel(s_head, - q_dst, + q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #endif @@ -3740,167 +3736,167 @@ static int sdpa_forward_prefill( { case 128: #if __AVX512F__ - qk_gemm_specialized_avx512<128>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<128>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<128>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 64: #if __AVX512F__ - qk_gemm_specialized_avx512<64>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<64>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<64>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 512: #if __AVX512F__ - qk_gemm_specialized_avx512<512>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<512>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif 
__SSE2__ - qk_gemm_specialized_sse2<512>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 256: #if __AVX512F__ - qk_gemm_specialized_avx512<256>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<256>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<256>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 32: #if __AVX512F__ - qk_gemm_specialized_avx512<32>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<32>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<32>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 80: #if __AVX512F__ - qk_gemm_specialized_avx512<80>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<80>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, 
block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<80>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 96: #if __AVX512F__ - qk_gemm_specialized_avx512<96>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<96>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<96>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 160: #if __AVX512F__ - qk_gemm_specialized_avx512<160>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<160>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<160>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 1024: #if __AVX512F__ - qk_gemm_specialized_avx512<1024>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<1024>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<1024>(s_head, q_ptr, 
key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<1024>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 2048: #if __AVX512F__ - qk_gemm_specialized_avx512<2048>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<2048>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<2048>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 4096: #if __AVX512F__ - qk_gemm_specialized_avx512<4096>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<4096>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<4096>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 768: #if __AVX512F__ - qk_gemm_specialized_avx512<768>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<768>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, 
_scale); + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<768>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 1536: #if __AVX512F__ - qk_gemm_specialized_avx512<1536>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<1536>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<1536>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif case 3072: #if __AVX512F__ - qk_gemm_specialized_avx512<3072>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<3072>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<3072>(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #endif default: #if __AVX512F__ - qk_gemm_avx512(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #elif __AVX__ - qk_gemm_avx(s_head, q_dst, key_head.row(n_start2), 
block_m, block_n2, embed_dim, _scale); + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #elif __SSE2__ - qk_gemm_sse2(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #else - qk_gemm_scalar(s_head, q_dst, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); #endif break; } From 2b2965a0ea88a75e2d987a8c3d5382f53427b9fe Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Tue, 5 May 2026 08:59:43 +0800 Subject: [PATCH 43/53] perf(sdpa_x86): fix MQA/GQA prefill regressions and optimize small seqlen - Revert qk_gemm_specialized_avx512<4096> tiling from 8,2 back to 2,2 to fix massive register spilling in !large_dim batched path (acc[8][4] = 32 ZMMs). - Add qk_gemm_specialized_avx512_large_m<4096> (2,2) for large_m path, but switch large_dim D=4096 QK GEMM to generic qk_gemm_avx512 (M=8,N=2) which has better ILP for small m/n. - Change large_dim threshold to embed_dim>512 && src_seqlen>16 so MQA seqlen=16 uses the faster batched !large_dim path. - Add n>=256 guard to QK/PV GEMM prefetch to avoid small-n overhead. - Disable outer OpenMP when num_group*num_m_tiles==1 to avoid parallel region overhead for single-tile configs. - Restructure large_dim path as N-outer/head-inner to reuse K/V tiles across heads, with inner OpenMP over heads when only one outer tile exists. 
Fixes: groups=4 seqlen=32: ~51ms -> ~8ms (-84%, restored correct s_head striding) groups=1 seqlen=16: ~15ms -> ~5ms (-66%, uses batched path) groups=32 seqlen=32: ~6ms -> ~6ms (unchanged) Remaining gap vs baseline: groups=1 seqlen=32: +63% (8.3ms vs 5.1ms, inherent in post-refactor path) groups=1 seqlen=64: +34% (30.2ms vs 22.6ms) --- src/layer/x86/sdpa_x86.cpp | 1185 ++++++++++++++++++++++++------------ 1 file changed, 803 insertions(+), 382 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 76b56d77bc10..93f95a758b64 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1584,7 +1584,18 @@ template<> void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx512<4096, 8, 2>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); +} + +template +static inline void qk_gemm_specialized_avx512_large_m(float* S, const float* Q, const float* K, + int m, int n, float scale) {} + +template<> +void qk_gemm_specialized_avx512_large_m<4096>(float* S, const float* Q, const float* K, + int m, int n, float scale) +{ + qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); } template @@ -3199,7 +3210,7 @@ static int sdpa_forward_prefill( if (m_state.empty() || l_state.empty()) return -100; - #pragma omp parallel for num_threads(opt.num_threads) + #pragma omp parallel for num_threads(opt.num_threads) if(num_group * num_m_tiles > 1) for (int idx = 0; idx < num_group * num_m_tiles; idx++) { int g = idx / num_m_tiles; @@ -3677,415 +3688,825 @@ static int sdpa_forward_prefill( } else { - // large_dim: copy Q per head once, then loop over N-tiles - for (int hq = 0; hq < num_heads_per_group; hq++) + // large_dim: N-outer loop, head-inner loop to reuse K/V tiles across heads + for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) { - int q = g * num_heads_per_group + hq; - const Mat 
query_head = query_ref.channel(q); + int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? n_start2 + BLOCK_N : dst_seqlen; + int block_n2 = n_end2 - n_start2; - const float* q_ptr = query_head.row(m_start); - - for (int n_start2 = 0; n_start2 < dst_seqlen; n_start2 += BLOCK_N) + if (num_group * num_m_tiles == 1) { - int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? n_start2 + BLOCK_N : dst_seqlen; - int block_n2 = n_end2 - n_start2; - - float* s_head = s_ptr; - -#if NCNN_BF16 - if (use_bf16_path) + #pragma omp parallel for num_threads(opt.num_threads) if(opt.num_threads > 1) + for (int hq = 0; hq < num_heads_per_group; hq++) { -#if __AVX512F__ -#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ - if (ncnn::cpu_support_x86_avx512_bf16()) + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr + hq * block_m * block_n2; + + #if NCNN_BF16 + if (use_bf16_path) { - qk_gemm_bf16s_avx512bf16(s_head, + #if __AVX512F__ + #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + qk_gemm_bf16s_avx512bf16(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else + #endif + { + qk_gemm_bf16s_avx512_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + #elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - } - else -#endif - { - qk_gemm_bf16s_avx512_kernel(s_head, + #elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + #else + qk_gemm_bf16s_scalar_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #endif } -#elif __AVX__ - qk_gemm_bf16s_avx_kernel(s_head, - q_ptr, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); -#elif __SSE2__ - 
qk_gemm_bf16s_sse2_kernel(s_head, - q_ptr, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); -#else - qk_gemm_bf16s_scalar_kernel(s_head, - q_ptr, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); -#endif - } - else -#endif + else + #endif + { + switch (embed_dim) { - switch (embed_dim) - { - case 128: -#if __AVX512F__ - qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 64: -#if __AVX512F__ - qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 512: -#if __AVX512F__ - qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 256: -#if __AVX512F__ - qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 32: -#if __AVX512F__ - qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), 
block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 80: -#if __AVX512F__ - qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 96: -#if __AVX512F__ - qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 160: -#if __AVX512F__ - qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 1024: -#if __AVX512F__ - qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 2048: -#if __AVX512F__ - qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; 
-#elif __AVX__ - qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 4096: -#if __AVX512F__ - qk_gemm_specialized_avx512<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 768: -#if __AVX512F__ - qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 1536: -#if __AVX512F__ - qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - case 3072: -#if __AVX512F__ - qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __AVX__ - qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#elif __SSE2__ - qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; -#endif - default: -#if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); -#elif __AVX__ - qk_gemm_avx(s_head, 
q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); -#elif __SSE2__ - qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); -#else - qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); -#endif - break; - } + case 128: + #if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 64: + #if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 512: + #if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 256: + #if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 32: + #if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, 
block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 80: + #if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 96: + #if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 160: + #if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 1024: + #if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 2048: + #if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); 
+ break; + #elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 4096: + #if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; + #elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; + #elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; + #endif + case 768: + #if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 1536: + #if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 3072: + #if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + default: + #if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #elif __AVX__ + qk_gemm_avx(s_head, 
q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #else + qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #endif + break; } - - if (attn_mask && mask_data[hq]) - { + } + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; + float* sptr = s_head + i * block_n2; + decode_mask_vec(sptr, mptr, block_n2); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; for (int i = 0; i < block_m; i++) { - const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; - float* sptr = s_head + i * block_n2; - decode_mask_vec(sptr, mptr, block_n2); + m_old[i] = m_vec[i]; } - } - - float* m_vec = m_state_tile.row(hq); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.row(hq * block_m); - - float m_old[BLOCK_M]; - float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (m_old[i] != m_vec[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + + #if NCNN_BF16 + if (use_bf16_path) + { + #if __AVX512F__ + #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + else + #endif + { + pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + #elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, 
out_embed_dim); + #elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + #else + pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + #endif + } + else + #endif + { + switch (out_embed_dim) { - m_old[i] = m_vec[i]; + case 128: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #endif + case 64: + #if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __SSE2__ + pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #endif + case 256: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; + #endif + case 512: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; + #endif + case 1024: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 
1024); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; + #endif + case 2048: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; + #endif + case 4096: + #if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; + #endif + case 768: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; + #endif + case 1536: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; + #endif + case 3072: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; + #endif + default: + #if 
__AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #endif + break; } - softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); - - for (int i = 0; i < block_m; i++) - { - if (m_old[i] != m_vec[i]) - { - vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } } - -#if NCNN_BF16 - if (use_bf16_path) + } + else + { + for (int hq = 0; hq < num_heads_per_group; hq++) { -#if __AVX512F__ -#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ - if (ncnn::cpu_support_x86_avx512_bf16()) + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr + hq * block_m * block_n2; + + #if NCNN_BF16 + if (use_bf16_path) { - pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); + #if __AVX512F__ + #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + qk_gemm_bf16s_avx512bf16(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + else + #endif + { + qk_gemm_bf16s_avx512_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + } + #elif __AVX__ + qk_gemm_bf16s_avx_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + #elif __SSE2__ + qk_gemm_bf16s_sse2_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); + #else + qk_gemm_bf16s_scalar_kernel(s_head, + q_ptr, + key_head.row(n_start2), + block_m, block_n2, embed_dim, 
_scale); + #endif } else -#endif + #endif + { + switch (embed_dim) + { + case 128: + #if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 64: + #if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 512: + #if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 256: + #if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 32: + #if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_ptr, 
key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 80: + #if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 96: + #if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 160: + #if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 1024: + #if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 2048: + #if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, 
key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 4096: + #if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; + #elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; + #elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; + #endif + case 768: + #if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 1536: + #if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + case 3072: + #if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; + #endif + default: + #if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #else + 
qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); + #endif + break; + } + } + + if (attn_mask && mask_data[hq]) + { + for (int i = 0; i < block_m; i++) + { + const float* mptr = mask_data[hq] + (m_start + i) * mask_stride[hq] + n_start2; + float* sptr = s_head + i * block_n2; + decode_mask_vec(sptr, mptr, block_n2); + } + } + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float m_old[BLOCK_M]; + float scale_factors[BLOCK_M]; + for (int i = 0; i < block_m; i++) { - pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), + m_old[i] = m_vec[i]; + } + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (m_old[i] != m_vec[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + + #if NCNN_BF16 + if (use_bf16_path) + { + #if __AVX512F__ + #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + if (ncnn::cpu_support_x86_avx512_bf16()) + { + pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + else + #endif + { + pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + } + #elif __AVX__ + pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + #elif __SSE2__ + pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), + block_m, block_n2, out_embed_dim); + #else + pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #endif } -#elif __AVX__ - pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); -#elif __SSE2__ - pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); -#else - pv_gemm_bf16s_scalar_kernel(o_ptr, 
s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); -#endif - } - else -#endif + else + #endif + { + switch (out_embed_dim) { - switch (out_embed_dim) - { - case 128: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; -#elif __AVX__ - pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; -#endif - case 64: -#if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; -#elif __AVX__ - pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; -#elif __SSE2__ - pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; -#endif - case 256: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; -#endif - case 512: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; -#endif - case 1024: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; -#endif - case 2048: -#if __AVX512F__ - 
pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; -#endif - case 4096: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; -#endif - case 768: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; -#endif - case 1536: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; -#endif - case 3072: -#if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; -#elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; -#elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; -#endif - default: -#if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); -#elif __AVX__ - pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), 
block_m, block_n2, out_embed_dim); -#elif __SSE2__ - pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); -#else - pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); -#endif - break; - } + case 128: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #endif + case 64: + #if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #elif __SSE2__ + pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; + #endif + case 256: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; + #endif + case 512: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; + #endif + case 1024: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2, 1024); + break; + #endif + case 2048: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; + #endif + case 4096: + #if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; + #endif + case 768: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; + #endif + case 1536: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; + #endif + case 3072: + #if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; + #elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; + #elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; + #endif + default: + #if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); + #endif + break; + } + } } } } From f39f525859d7f750f8cc9818d5b734e0e88ac7c7 Mon Sep 17 00:00:00 2001 From: futz12 <56149058+futz12@users.noreply.github.com> Date: Tue, 5 May 2026 01:07:03 +0000 Subject: [PATCH 44/53] apply code-format changes --- src/layer/x86/sdpa_x86.cpp | 2071 ++++++++++++++++---------------- src/layer/x86/sdpa_x86_bf16s.h | 33 +- 2 files changed, 1050 insertions(+), 1054 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 93f95a758b64..41ccf8783b5e 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1589,11 +1589,13 @@ void qk_gemm_specialized_avx512<4096>(float* S, const float* Q, const float* K, template static inline void qk_gemm_specialized_avx512_large_m(float* S, const float* Q, const float* K, - int m, int n, float scale) {} + int m, int n, float scale) +{ +} template<> void qk_gemm_specialized_avx512_large_m<4096>(float* S, const float* Q, const float* K, - int m, int n, float scale) + int m, int n, float scale) { qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); } @@ -3210,7 +3212,7 @@ static int sdpa_forward_prefill( if (m_state.empty() || l_state.empty()) return -100; - #pragma omp parallel for num_threads(opt.num_threads) if(num_group * num_m_tiles > 1) + #pragma omp parallel for num_threads(opt.num_threads) if (num_group * num_m_tiles > 1) for (int idx = 0; idx < num_group * num_m_tiles; idx++) { int g = idx / num_m_tiles; @@ -3696,20 +3698,20 @@ static int sdpa_forward_prefill( if (num_group * num_m_tiles == 1) { - #pragma omp parallel for num_threads(opt.num_threads) 
if(opt.num_threads > 1) + #pragma omp parallel for num_threads(opt.num_threads) if (opt.num_threads > 1) for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; const Mat query_head = query_ref.channel(q); - + const float* q_ptr = query_head.row(m_start); float* s_head = s_ptr + hq * block_m * block_n2; - - #if NCNN_BF16 + +#if NCNN_BF16 if (use_bf16_path) { - #if __AVX512F__ - #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx512_bf16()) { qk_gemm_bf16s_avx512bf16(s_head, @@ -3718,203 +3720,203 @@ static int sdpa_forward_prefill( block_m, block_n2, embed_dim, _scale); } else - #endif +#endif { qk_gemm_bf16s_avx512_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); } - #elif __AVX__ +#elif __AVX__ qk_gemm_bf16s_avx_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #elif __SSE2__ +#elif __SSE2__ qk_gemm_bf16s_sse2_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #else +#else qk_gemm_bf16s_scalar_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #endif +#endif } else - #endif +#endif { - switch (embed_dim) - { - case 128: - #if __AVX512F__ - qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 64: - #if __AVX512F__ - qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - 
qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 512: - #if __AVX512F__ - qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 256: - #if __AVX512F__ - qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 32: - #if __AVX512F__ - qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 80: - #if __AVX512F__ - qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 96: - #if __AVX512F__ - qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - 
qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 160: - #if __AVX512F__ - qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 1024: - #if __AVX512F__ - qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 2048: - #if __AVX512F__ - qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 4096: - #if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); - break; - #elif __AVX__ - qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); - break; - #elif __SSE2__ - qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); - break; - #endif - case 768: - #if __AVX512F__ - qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<768>(s_head, 
q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 1536: - #if __AVX512F__ - qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 3072: - #if __AVX512F__ - qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - default: - #if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #elif __AVX__ - qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #elif __SSE2__ - qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #else - qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #endif - break; - } + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, 
block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 32: +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + 
qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_ptr, 
key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#endif + break; + } } - + if (attn_mask && mask_data[hq]) { for (int i = 0; i < block_m; i++) @@ -3924,11 +3926,11 @@ static int sdpa_forward_prefill( decode_mask_vec(sptr, mptr, block_n2); } } - + float* m_vec = m_state_tile.row(hq); float* l_vec = l_state_tile.row(hq); float* o_ptr = o_accum_thread.row(hq * block_m); - + float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; for (int i = 0; i < block_m; i++) @@ -3936,7 +3938,7 @@ static int sdpa_forward_prefill( m_old[i] = m_vec[i]; } softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); - + for (int i = 0; i < block_m; i++) { if (m_old[i] != m_vec[i]) @@ -3944,161 +3946,161 @@ static int sdpa_forward_prefill( 
vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } } - - #if NCNN_BF16 + +#if NCNN_BF16 if (use_bf16_path) { - #if __AVX512F__ - #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx512_bf16()) { pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); } else - #endif +#endif { pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); } - #elif __AVX__ +#elif __AVX__ pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #elif __SSE2__ +#elif __SSE2__ pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #else +#else pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #endif +#endif } else - #endif +#endif { - switch (out_embed_dim) - { - case 128: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __AVX__ - pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #endif - case 64: - #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __AVX__ - pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __SSE2__ - pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #endif - case 256: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 
32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; - #endif - case 512: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; - #endif - case 1024: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; - #endif - case 2048: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; - #endif - case 4096: - #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; - #endif - case 768: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; - #endif - case 1536: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2, 1536); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; - #endif - case 3072: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; - #endif - default: - #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #elif __AVX__ - pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #elif __SSE2__ - pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #else - pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #endif - break; - } + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; 
+#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#endif + case 512: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#endif + break; + } } } } @@ -4108,15 +4110,15 @@ static int sdpa_forward_prefill( { int q = g * num_heads_per_group + hq; const Mat query_head = query_ref.channel(q); - + const float* q_ptr = query_head.row(m_start); float* s_head = s_ptr + hq * block_m * block_n2; - - #if NCNN_BF16 + +#if NCNN_BF16 if (use_bf16_path) { - #if __AVX512F__ - #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx512_bf16()) { qk_gemm_bf16s_avx512bf16(s_head, @@ -4125,203 +4127,203 @@ static int sdpa_forward_prefill( block_m, block_n2, embed_dim, _scale); } else - #endif +#endif { qk_gemm_bf16s_avx512_kernel(s_head, q_ptr, key_head.row(n_start2), 
block_m, block_n2, embed_dim, _scale); } - #elif __AVX__ +#elif __AVX__ qk_gemm_bf16s_avx_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #elif __SSE2__ +#elif __SSE2__ qk_gemm_bf16s_sse2_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #else +#else qk_gemm_bf16s_scalar_kernel(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #endif +#endif } else - #endif +#endif { - switch (embed_dim) - { - case 128: - #if __AVX512F__ - qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 64: - #if __AVX512F__ - qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 512: - #if __AVX512F__ - qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 256: - #if __AVX512F__ - qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - 
qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 32: - #if __AVX512F__ - qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 80: - #if __AVX512F__ - qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 96: - #if __AVX512F__ - qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 160: - #if __AVX512F__ - qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 1024: - #if __AVX512F__ - qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - 
qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 2048: - #if __AVX512F__ - qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 4096: - #if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); - break; - #elif __AVX__ - qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); - break; - #elif __SSE2__ - qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); - break; - #endif - case 768: - #if __AVX512F__ - qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 1536: - #if __AVX512F__ - qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - case 3072: - #if __AVX512F__ - qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __AVX__ - qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #elif __SSE2__ - 
qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); - break; - #endif - default: - #if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #elif __AVX__ - qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #elif __SSE2__ - qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #else - qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); - #endif - break; - } + switch (embed_dim) + { + case 128: +#if __AVX512F__ + qk_gemm_specialized_avx512<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<128>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 64: +#if __AVX512F__ + qk_gemm_specialized_avx512<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<64>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 512: +#if __AVX512F__ + qk_gemm_specialized_avx512<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<512>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 256: +#if __AVX512F__ + qk_gemm_specialized_avx512<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<256>(s_head, q_ptr, 
key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<256>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 32: +#if __AVX512F__ + qk_gemm_specialized_avx512<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<32>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 80: +#if __AVX512F__ + qk_gemm_specialized_avx512<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<80>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 96: +#if __AVX512F__ + qk_gemm_specialized_avx512<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<96>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 160: +#if __AVX512F__ + qk_gemm_specialized_avx512<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<160>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1024: +#if __AVX512F__ + qk_gemm_specialized_avx512<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 
_scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1024>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 2048: +#if __AVX512F__ + qk_gemm_specialized_avx512<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<2048>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 4096: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + break; +#endif + case 768: +#if __AVX512F__ + qk_gemm_specialized_avx512<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<768>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 1536: +#if __AVX512F__ + qk_gemm_specialized_avx512<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + qk_gemm_specialized_sse2<1536>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + case 3072: +#if __AVX512F__ + qk_gemm_specialized_avx512<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __AVX__ + qk_gemm_specialized_avx<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#elif __SSE2__ + 
qk_gemm_specialized_sse2<3072>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + break; +#endif + default: +#if __AVX512F__ + qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __AVX__ + qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#elif __SSE2__ + qk_gemm_sse2(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#else + qk_gemm_scalar(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, embed_dim, _scale); +#endif + break; + } } - + if (attn_mask && mask_data[hq]) { for (int i = 0; i < block_m; i++) @@ -4331,11 +4333,11 @@ static int sdpa_forward_prefill( decode_mask_vec(sptr, mptr, block_n2); } } - + float* m_vec = m_state_tile.row(hq); float* l_vec = l_state_tile.row(hq); float* o_ptr = o_accum_thread.row(hq * block_m); - + float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; for (int i = 0; i < block_m; i++) @@ -4343,7 +4345,7 @@ static int sdpa_forward_prefill( m_old[i] = m_vec[i]; } softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); - + for (int i = 0; i < block_m; i++) { if (m_old[i] != m_vec[i]) @@ -4351,181 +4353,181 @@ static int sdpa_forward_prefill( vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } } - - #if NCNN_BF16 + +#if NCNN_BF16 if (use_bf16_path) { - #if __AVX512F__ - #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx512_bf16()) { pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); } else - #endif +#endif { pv_gemm_bf16s_avx512_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); } - #elif __AVX__ +#elif __AVX__ pv_gemm_bf16s_avx_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #elif __SSE2__ +#elif __SSE2__ 
pv_gemm_bf16s_sse2_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #else +#else pv_gemm_bf16s_scalar_kernel(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #endif +#endif } else - #endif +#endif { - switch (out_embed_dim) - { - case 128: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __AVX__ - pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #endif - case 64: - #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __AVX__ - pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #elif __SSE2__ - pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); - break; - #endif - case 256: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); - break; - #endif - case 512: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); - break; - #endif - case 1024: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; - #elif __SSE2__ - 
pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); - break; - #endif - case 2048: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); - break; - #endif - case 4096: - #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); - break; - #endif - case 768: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); - break; - #endif - case 1536: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); - break; - #endif - case 3072: - #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; - #elif __AVX__ - pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; - #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); - break; - #endif - default: - #if __AVX512F__ - pv_gemm_avx512<4, 
64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #elif __AVX__ - pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #elif __SSE2__ - pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #else - pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); - #endif - break; - } - } - } - } - } - } - } - - // Normalize all Q heads for this M tile and write back to top_blob - for (int hq = 0; hq < num_heads_per_group; hq++) - { - int q = g * num_heads_per_group + hq; - Mat top_blob_head = top_blob.channel(q); - float* l_vec = l_state_tile.row(hq); - float* o_ptr = o_accum_thread.row(hq * block_m); - - for (int i = 0; i < block_m; i++) - { - float* outptr = top_blob_head.row(m_start + i); - float inv_l = 1.f / l_vec[i]; - int k = 0; + switch (out_embed_dim) + { + case 128: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 64: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __AVX__ + pv_gemm_avx<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#elif __SSE2__ + pv_gemm_sse2<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2); + break; +#endif + case 256: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 256); + break; +#endif + case 512: +#if 
__AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 512); + break; +#endif + case 1024: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1024); + break; +#endif + case 2048: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 2048); + break; +#endif + case 4096: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + break; +#endif + case 768: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 768); + break; +#endif + case 1536: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, 
value_head.row(n_start2), block_m, block_n2, 1536); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 1536); + break; +#endif + case 3072: +#if __AVX512F__ + pv_gemm_avx512<2, 128>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __AVX__ + pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#elif __SSE2__ + pv_gemm_sse2<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 3072); + break; +#endif + default: +#if __AVX512F__ + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __AVX__ + pv_gemm_avx<2, 32>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#elif __SSE2__ + pv_gemm_sse2<2, 16>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#else + pv_gemm_scalar(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, out_embed_dim); +#endif + break; + } + } + } + } + } + } + } + + // Normalize all Q heads for this M tile and write back to top_blob + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + Mat top_blob_head = top_blob.channel(q); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + for (int i = 0; i < block_m; i++) + { + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; #if __AVX512F__ __m512 vinv_l = _mm512_set1_ps(inv_l); for (; k + 15 < out_embed_dim; k += 16) @@ -4568,12 +4570,11 @@ static int sdpa_forward_prefill( return 0; } - static int sdpa_quantize_key_value_int8_x86(const Mat& key, const Mat& value, - Mat& key_int8, Mat& key_scales, Mat& value_int8, Mat& value_scales, - int num_group, int dst_seqlen, int embed_dim, int out_embed_dim, int v_num_blocks, - bool cache_valid, int past_seqlen, bool kv_cache, - const Option& opt) + Mat& key_int8, Mat& key_scales, Mat& 
value_int8, Mat& value_scales, + int num_group, int dst_seqlen, int embed_dim, int out_embed_dim, int v_num_blocks, + bool cache_valid, int past_seqlen, bool kv_cache, + const Option& opt) { #pragma omp parallel for num_threads(opt.num_threads) for (int g = 0; g < num_group; g++) @@ -4789,359 +4790,359 @@ static int sdpa_prefill_int8_x86( { const int BLOCK_M = 64; const int BLOCK_N = 128; -#pragma omp parallel for num_threads(opt.num_threads) -for (int q = 0; q < num_heads; q++) -{ - const Mat query_head = query.channel(q); - const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); - const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); - const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); - const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); - Mat top_blob_head = top_blob.channel(q); - - Mat mask_head; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - } - - Mat o_accum_head = o_accum.channel(get_omp_thread_num()); - float* s_vec_ptr = s_vec.row(get_omp_thread_num()); - float* p_vec_ptr = p_vec.row(get_omp_thread_num()); - Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); - Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); - - for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - int m_end = m_start + BLOCK_M < src_seqlen ? 
m_start + BLOCK_M : src_seqlen; - int block_m = m_end - m_start; + const Mat query_head = query.channel(q); + const Mat key_int8_head = key_int8.channel(q / num_heads_per_group); + const Mat key_scales_head = key_scales.channel(q / num_heads_per_group); + const Mat value_int8_head = value_int8.channel(q / num_heads_per_group); + const Mat value_scales_head = value_scales.channel(q / num_heads_per_group); + Mat top_blob_head = top_blob.channel(q); - for (int i = 0; i < block_m; i++) + Mat mask_head; + if (attn_mask) { - dynamic_quantize_rowwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); - q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; + const Mat& maskm = attn_mask_ref; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } } - for (int i = 0; i < block_m; i++) - { - float* optr = o_accum_head.row(i); - vec_zero(optr, out_embed_dim); - } + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr = s_vec.row(get_omp_thread_num()); + float* p_vec_ptr = p_vec.row(get_omp_thread_num()); + Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); + Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); - float m_vec[BLOCK_M]; - float l_vec[BLOCK_M]; - for (int i = 0; i < block_m; i++) + for (int m_start = 0; m_start < src_seqlen; m_start += BLOCK_M) { - m_vec[i] = -FLT_MAX; - l_vec[i] = 0.f; - } + int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) - { - int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; - int block_n = n_end - n_start; - - if (block_m == 1) + for (int i = 0; i < block_m; i++) { - qk_int8_gemm_row(s_vec_ptr, - q_int8_tile_head.row(0), - key_int8_head.row(n_start), - q_scales_tile_head.row(0)[0], - key_scales_head.row(n_start), - block_n, embed_dim, _scale); + dynamic_quantize_rowwise(query_head.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; } - else + + for (int i = 0; i < block_m; i++) { - qk_int8_gemm_tiled(s_vec_ptr, - q_int8_tile_head.row(0), - key_int8_head.row(n_start), - q_scales_tile_head.row(0), - key_scales_head.row(n_start), - block_m, block_n, embed_dim, _scale); + float* optr = o_accum_head.row(i); + vec_zero(optr, out_embed_dim); } + float m_vec[BLOCK_M]; + float l_vec[BLOCK_M]; for (int i = 0; i < block_m; i++) { - float* s_row = (block_m == 1) ? s_vec_ptr : (s_vec_ptr + i * block_n); - float* p_row = (block_m == 1) ? p_vec_ptr : (p_vec_ptr + i * block_n); + m_vec[i] = -FLT_MAX; + l_vec[i] = 0.f; + } - if (attn_mask) - decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(m_vec[i]); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s_row + j)); - if (j < block_n) + if (block_m == 1) + { + qk_int8_gemm_row(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0)[0], + key_scales_head.row(n_start), + block_n, embed_dim, _scale); + } + else { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s_row + j)); + qk_int8_gemm_tiled(s_vec_ptr, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); } - float m_new = _mm512_comp_reduce_max_ps(vmax); + + for (int i = 0; i < block_m; i++) + { + float* s_row = (block_m == 1) ? s_vec_ptr : (s_vec_ptr + i * block_n); + float* p_row = (block_m == 1) ? 
p_vec_ptr : (p_vec_ptr + i * block_n); + + if (attn_mask) + decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); + +#if __AVX512F__ + __m512 vmax = _mm512_set1_ps(m_vec[i]); + int j = 0; + for (; j + 15 < block_n; j += 16) + vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s_row + j)); + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s_row + j)); + } + float m_new = _mm512_comp_reduce_max_ps(vmax); #elif __AVX__ - __m256 vmax = _mm256_set1_ps(m_vec[i]); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s_row + j)); - float m_new = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); + __m256 vmax = _mm256_set1_ps(m_vec[i]); + int j = 0; + for (; j + 7 < block_n; j += 8) + vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s_row + j)); + float m_new = _mm256_reduce_max_ps(vmax); + for (; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); #elif __SSE2__ - __m128 vmax = _mm_set1_ps(m_vec[i]); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s_row + j)); - float m_new = _mm_reduce_max_ps(vmax); - for (; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); + __m128 vmax = _mm_set1_ps(m_vec[i]); + int j = 0; + for (; j + 3 < block_n; j += 4) + vmax = _mm_max_ps(vmax, _mm_loadu_ps(s_row + j)); + float m_new = _mm_reduce_max_ps(vmax); + for (; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); #else - float m_new = m_vec[i]; - for (int j = 0; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); + float m_new = m_vec[i]; + for (int j = 0; j < block_n; j++) + m_new = std::max(m_new, s_row[j]); #endif - float scale_factor = expf(m_vec[i] - m_new); - float l_new = l_vec[i] * scale_factor; + float scale_factor = expf(m_vec[i] - m_new); + float l_new = l_vec[i] * scale_factor; - float* optr = o_accum_head.row(i); - vec_scale(optr, 
scale_factor, out_embed_dim); + float* optr = o_accum_head.row(i); + vec_scale(optr, scale_factor, out_embed_dim); #if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(m_new); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s_row + j), vm_new)); - _mm512_storeu_ps(p_row + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, s_row + j), vm_new)); - _mm512_mask_storeu_ps(p_row + j, mask, pvec); - vsum = _mm512_mask_add_ps(vsum, mask, vsum, pvec); - } - l_new += _mm512_comp_reduce_add_ps(vsum); + __m512 vm_new = _mm512_set1_ps(m_new); + __m512 vsum = _mm512_setzero_ps(); + j = 0; + for (; j + 15 < block_n; j += 16) + { + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s_row + j), vm_new)); + _mm512_storeu_ps(p_row + j, pvec); + vsum = _mm512_add_ps(vsum, pvec); + } + if (j < block_n) + { + __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); + __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, s_row + j), vm_new)); + _mm512_mask_storeu_ps(p_row + j, mask, pvec); + vsum = _mm512_mask_add_ps(vsum, mask, vsum, pvec); + } + l_new += _mm512_comp_reduce_add_ps(vsum); #elif __AVX__ - __m256 vm_new = _mm256_set1_ps(m_new); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s_row + j), vm_new)); - _mm256_storeu_ps(p_row + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - l_new += _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } + __m256 vm_new = _mm256_set1_ps(m_new); + __m256 vsum = _mm256_setzero_ps(); + j = 0; + for (; j + 7 < block_n; j += 8) + { + __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s_row + j), vm_new)); + _mm256_storeu_ps(p_row + j, 
pvec); + vsum = _mm256_add_ps(vsum, pvec); + } + l_new += _mm256_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } #elif __SSE2__ - __m128 vm_new = _mm_set1_ps(m_new); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s_row + j), vm_new)); - _mm_storeu_ps(p_row + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - l_new += _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } + __m128 vm_new = _mm_set1_ps(m_new); + __m128 vsum = _mm_setzero_ps(); + j = 0; + for (; j + 3 < block_n; j += 4) + { + __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s_row + j), vm_new)); + _mm_storeu_ps(p_row + j, pvec); + vsum = _mm_add_ps(vsum, pvec); + } + l_new += _mm_reduce_add_ps(vsum); + for (; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } #else - for (int j = 0; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } + for (int j = 0; j < block_n; j++) + { + p_row[j] = expf(s_row[j] - m_new); + l_new += p_row[j]; + } #endif - m_vec[i] = m_new; - l_vec[i] = l_new; - } + m_vec[i] = m_new; + l_vec[i] = l_new; + } - if (kv_cache) - { - pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, - value_int8_head.row(n_start), - value_scales_head.row(n_start), - block_m, block_n, out_embed_dim); - } - else - { - switch (out_embed_dim) - { - case 128: + if (kv_cache) + { + pv_float_int8_gemm_tile(o_accum_head.row(0), p_vec_ptr, + value_int8_head.row(n_start), + value_scales_head.row(n_start), + block_m, block_n, out_embed_dim); + } + else + { + switch (out_embed_dim) + { + case 128: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / 
num_heads_per_group).row(n_start), block_m, block_n); + break; #elif __AVX__ - pv_gemm_avx<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); - break; + pv_gemm_avx<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; #elif __SSE2__ - pv_gemm_sse2<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); - break; + pv_gemm_sse2<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; #endif - case 64: + case 64: #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); - break; + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; #elif __AVX__ - pv_gemm_avx<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); - break; + pv_gemm_avx<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; #elif __SSE2__ - pv_gemm_sse2<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); - break; + pv_gemm_sse2<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n); + break; #endif - case 256: + case 256: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); - 
break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 256); + break; #endif - case 512: + case 512: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 512); + break; #endif - case 1024: + case 1024: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; #elif 
__SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1024); + break; #endif - case 2048: + case 2048: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 2048); + break; #endif - case 4096: + case 4096: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); - 
break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + break; #endif - case 768: + case 768: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 768); + break; #endif - case 1536: + case 1536: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 1536); + break; 
#endif - case 3072: + case 3072: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); - break; + pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; #elif __AVX__ - pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); - break; + pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; #elif __SSE2__ - pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); - break; + pv_gemm_sse2<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 3072); + break; #endif - default: + default: #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); #elif __AVX__ - pv_gemm_avx<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); + pv_gemm_avx<2, 32>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); #elif __SSE2__ - pv_gemm_sse2<2, 16>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); + pv_gemm_sse2<2, 16>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); #else - pv_gemm_scalar(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); + 
pv_gemm_scalar(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, out_embed_dim); #endif - break; - } + break; + } + } } - } - for (int i = 0; i < block_m; i++) - { - float* optr = o_accum_head.row(i); - float* outptr = top_blob_head.row(m_start + i); - float inv_l = 1.f / l_vec[i]; - int k = 0; -#if __AVX512F__ - __m512 vinv_l = _mm512_set1_ps(inv_l); - for (; k + 15 < out_embed_dim; k += 16) - _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(optr + k), vinv_l)); - if (k < out_embed_dim) + for (int i = 0; i < block_m; i++) { - __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); - _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, optr + k), vinv_l)); - } + float* optr = o_accum_head.row(i); + float* outptr = top_blob_head.row(m_start + i); + float inv_l = 1.f / l_vec[i]; + int k = 0; +#if __AVX512F__ + __m512 vinv_l = _mm512_set1_ps(inv_l); + for (; k + 15 < out_embed_dim; k += 16) + _mm512_storeu_ps(outptr + k, _mm512_mul_ps(_mm512_loadu_ps(optr + k), vinv_l)); + if (k < out_embed_dim) + { + __mmask16 mask = (__mmask16)((1u << (out_embed_dim - k)) - 1); + _mm512_mask_storeu_ps(outptr + k, mask, _mm512_mul_ps(_mm512_maskz_loadu_ps(mask, optr + k), vinv_l)); + } #elif __AVX__ - __m256 vinv_l256 = _mm256_set1_ps(inv_l); - for (; k + 7 < out_embed_dim; k += 8) - _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(optr + k), vinv_l256)); + __m256 vinv_l256 = _mm256_set1_ps(inv_l); + for (; k + 7 < out_embed_dim; k += 8) + _mm256_storeu_ps(outptr + k, _mm256_mul_ps(_mm256_loadu_ps(optr + k), vinv_l256)); #elif __SSE2__ - __m128 vinv_l128 = _mm_set1_ps(inv_l); - for (; k + 3 < out_embed_dim; k += 4) - _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(optr + k), vinv_l128)); + __m128 vinv_l128 = _mm_set1_ps(inv_l); + for (; k + 3 < out_embed_dim; k += 4) + _mm_storeu_ps(outptr + k, _mm_mul_ps(_mm_loadu_ps(optr + k), vinv_l128)); #endif - for (; k < 
out_embed_dim; k++) - outptr[k] = optr[k] * inv_l; + for (; k < out_embed_dim; k++) + outptr[k] = optr[k] * inv_l; + } } } -} return 0; } @@ -5499,77 +5500,113 @@ static int sdpa_decode_x86( { const int BLOCK_N = 128; const bool use_split_kv = opt.num_threads > 1 && dst_seqlen >= BLOCK_N * 2; -if (use_split_kv) -{ - const int num_kv_chunks = opt.num_threads; - Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); - if (partials.empty()) - return -100; - - #pragma omp parallel for num_threads(opt.num_threads) - for (int task = 0; task < num_heads * num_kv_chunks; task++) + if (use_split_kv) { - int q = task / num_kv_chunks; - int chunk = task % num_kv_chunks; + const int num_kv_chunks = opt.num_threads; + Mat partials(2 + out_embed_dim, num_kv_chunks, num_heads, 4u, opt.workspace_allocator); + if (partials.empty()) + return -100; - int n_start = chunk * dst_seqlen / num_kv_chunks; - int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_heads * num_kv_chunks; task++) + { + int q = task / num_kv_chunks; + int chunk = task % num_kv_chunks; - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); + int n_start = chunk * dst_seqlen / num_kv_chunks; + int n_end = (chunk + 1 == num_kv_chunks) ? dst_seqlen : (chunk + 1) * dst_seqlen / num_kv_chunks; - const float* qptr = query_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); + const Mat query_head = query_ref.channel(q); - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? 
maskm.channel(q) : maskm.channel(0); - } - else + const float* qptr = query_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) { - mask_head = maskm; + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); } - mask_ptr = mask_head.row(0); + + float* p = partials.channel(q).row(chunk); + sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, + n_start, n_end, embed_dim, out_embed_dim, _scale); } - float* p = partials.channel(q).row(chunk); - sdpa_decode_chunk(p + 2, p, p + 1, qptr, Kptr, Vptr, mask_ptr, - n_start, n_end, embed_dim, out_embed_dim, _scale); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) + { + Mat top_blob_head = top_blob.channel(q); + float* outptr = top_blob_head.row(0); + sdpa_decode_reduce(outptr, out_embed_dim, + partials.channel(q), num_kv_chunks, 2 + out_embed_dim); + } } - - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) + else { - Mat top_blob_head = top_blob.channel(q); - float* outptr = top_blob_head.row(0); - sdpa_decode_reduce(outptr, out_embed_dim, - partials.channel(q), num_kv_chunks, 2 + out_embed_dim); - } -} -else -{ - // Decode path: fused GEMV kernel for single-query attention - const bool group_parallel = num_group >= opt.num_threads; + // Decode path: fused GEMV kernel for single-query attention + const bool group_parallel = num_group >= opt.num_threads; - if (group_parallel) - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int g = 0; g < num_group; g++) + if (group_parallel) { - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); + #pragma omp parallel for num_threads(opt.num_threads) + for (int g = 0; g < num_group; g++) + { + const Mat key_head = 
key.channel(g); + const Mat value_head = value.channel(g); - for (int hq = 0; hq < num_heads_per_group; hq++) + for (int hq = 0; hq < num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + Mat top_blob_head = top_blob.channel(q); + + const float* qptr = query_head.row(0); + float* outptr = top_blob_head.row(0); + const float* Kptr = key_head; + const float* Vptr = value_head; + + const float* mask_ptr = nullptr; + if (attn_mask) + { + const Mat& maskm = attn_mask_ref; + Mat mask_head; + if (maskm.dims == 3) + { + mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); + } + else + { + mask_head = maskm; + } + mask_ptr = mask_head.row(0); + } + + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + } + } + } + else + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_heads; q++) { - int q = g * num_heads_per_group + hq; + int g = q / num_heads_per_group; + const Mat key_head = key.channel(g); + const Mat value_head = value.channel(g); const Mat query_head = query_ref.channel(q); Mat top_blob_head = top_blob.channel(q); @@ -5594,46 +5631,10 @@ else mask_ptr = mask_head.row(0); } - sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); + sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); } } } - else - { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < num_heads; q++) - { - int g = q / num_heads_per_group; - const Mat key_head = key.channel(g); - const Mat value_head = value.channel(g); - const Mat query_head = query_ref.channel(q); - Mat top_blob_head = top_blob.channel(q); - - const float* qptr = query_head.row(0); - float* outptr = top_blob_head.row(0); - const float* Kptr = key_head; - const float* Vptr = value_head; - - const float* mask_ptr = nullptr; - if (attn_mask) - { - const Mat& maskm = 
attn_mask_ref; - Mat mask_head; - if (maskm.dims == 3) - { - mask_head = maskm.c > 1 ? maskm.channel(q) : maskm.channel(0); - } - else - { - mask_head = maskm; - } - mask_ptr = mask_head.row(0); - } - - sdpa_decode(outptr, qptr, Kptr, Vptr, mask_ptr, dst_seqlen, embed_dim, out_embed_dim, _scale); - } - } -} return 0; } @@ -5797,8 +5798,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to } sdpa_quantize_key_value_int8_x86(key, value, key_int8, key_scales, value_int8, value_scales, - num_group, dst_seqlen, embed_dim, out_embed_dim, v_num_blocks, - cache_valid, past_seqlen, kv_cache, opt); + num_group, dst_seqlen, embed_dim, out_embed_dim, v_num_blocks, + cache_valid, past_seqlen, kv_cache, opt); Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); @@ -5812,9 +5813,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (src_seqlen == 1) { int ret = sdpa_decode_int8_x86(query, key_int8, key_scales, value_int8, value_scales, - top_blob, attn_mask_ref, opt, embed_dim, num_heads, num_group, out_embed_dim, - dst_seqlen, num_heads_per_group, _scale, attn_mask, - o_accum, s_vec, q_int8_tile, q_scales_tile); + top_blob, attn_mask_ref, opt, embed_dim, num_heads, num_group, out_embed_dim, + dst_seqlen, num_heads_per_group, _scale, attn_mask, + o_accum, s_vec, q_int8_tile, q_scales_tile); if (ret != 0) return ret; @@ -5835,9 +5836,9 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to else { int ret = sdpa_prefill_int8_x86(query, key_int8, key_scales, value_int8, value_scales, - top_blob, attn_mask_ref, opt, embed_dim, num_heads, out_embed_dim, src_seqlen, - dst_seqlen, num_heads_per_group, _scale, attn_mask, kv_cache, - o_accum, s_vec, p_vec, q_int8_tile, q_scales_tile, value); + top_blob, attn_mask_ref, opt, embed_dim, num_heads, out_embed_dim, src_seqlen, + dst_seqlen, num_heads_per_group, _scale, 
attn_mask, kv_cache, + o_accum, s_vec, p_vec, q_int8_tile, q_scales_tile, value); if (ret != 0) return ret; } @@ -5871,8 +5872,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to if (use_bf16_path) { int ret = sdpa_decode_bf16s_x86(query_ref, key, value, top_blob, attn_mask_ref, opt, - embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, - num_heads_per_group, _scale, attn_mask); + embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, attn_mask); if (ret != 0) return ret; @@ -5887,8 +5888,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to #endif // NCNN_BF16 int ret = sdpa_decode_x86(query_ref, key, value, top_blob, attn_mask_ref, opt, - embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, - num_heads_per_group, _scale, attn_mask); + embed_dim, num_heads, num_group, out_embed_dim, dst_seqlen, + num_heads_per_group, _scale, attn_mask); if (ret != 0) return ret; diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h index 60765374b34e..231b9f73bbe0 100644 --- a/src/layer/x86/sdpa_x86_bf16s.h +++ b/src/layer/x86/sdpa_x86_bf16s.h @@ -263,7 +263,6 @@ static inline void decode_qk_dot_bf16s_scalar_kernel(float* s, const float* q, c } } - // --------------------------------------------------------------------------- // decode_pv_gemv_bf16s : S(fp32) gemv V(bf16) -> out(fp32) // --------------------------------------------------------------------------- @@ -407,7 +406,6 @@ static inline void decode_pv_gemv_bf16s_scalar_kernel(float* out, const float* s } } - // --------------------------------------------------------------------------- // sdpa_decode_reduce_bf16s : reduce partial results from split-kv chunks // --------------------------------------------------------------------------- @@ -503,14 +501,13 @@ static inline void sdpa_decode_reduce_bf16s( } } - // --------------------------------------------------------------------------- // qk_gemm_bf16s : 
Q(fp32) x K^T(bf16) -> S(fp32) [prefill] // --------------------------------------------------------------------------- #if __AVX512F__ static void qk_gemm_bf16s_avx512_kernel(float* S, const float* Q, const unsigned short* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { int i = 0; for (; i + 8 <= m; i += 8) @@ -750,10 +747,9 @@ static void qk_gemm_bf16s_avx512_kernel(float* S, const float* Q, const unsigned } #endif // __AVX512F__ - #if __AVX__ static void qk_gemm_bf16s_avx_kernel(float* S, const float* Q, const unsigned short* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { int i = 0; for (; i + 6 <= m; i += 6) @@ -880,7 +876,7 @@ static void qk_gemm_bf16s_avx_kernel(float* S, const float* Q, const unsigned sh #if __SSE2__ static void qk_gemm_bf16s_sse2_kernel(float* S, const float* Q, const unsigned short* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { for (int i = 0; i < m; i++) { @@ -930,7 +926,7 @@ static void qk_gemm_bf16s_sse2_kernel(float* S, const float* Q, const unsigned s #endif // __SSE2__ static void qk_gemm_bf16s_scalar_kernel(float* S, const float* Q, const unsigned short* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { for (int i = 0; i < m; i++) { @@ -946,7 +942,6 @@ static void qk_gemm_bf16s_scalar_kernel(float* S, const float* Q, const unsigned } } - // --------------------------------------------------------------------------- // pv_gemm_bf16s : P(fp32) x V(bf16) -> O(fp32) [prefill] // --------------------------------------------------------------------------- @@ -1413,7 +1408,7 @@ static void decode_pv_gemv_bf16s_avx512bf16_kernel(float* out, const float* s, c } static void qk_gemm_bf16s_avx512bf16_kernel(float* S, const float* Q, const unsigned short* K, - int m, int n, int d, float scale) + int m, int n, int d, float scale) { unsigned short* q_bf16 = (unsigned short*)_mm_malloc(m * d * sizeof(unsigned short), 64); @@ -1718,8 
+1713,10 @@ static void pv_gemm_bf16s_avx512bf16_kernel(float* O, const float* P, const unsi acc10 = _mm512_dpbf16_ps(acc10, (__m512bh)p_ext1, (__m512bh)v0_ext); acc11 = _mm512_dpbf16_ps(acc11, (__m512bh)p_ext1, (__m512bh)v1_ext); } - _mm512_storeu_ps(op0 + 0, acc00); _mm512_storeu_ps(op1 + 0, acc10); - _mm512_storeu_ps(op0 + 16, acc01); _mm512_storeu_ps(op1 + 16, acc11); + _mm512_storeu_ps(op0 + 0, acc00); + _mm512_storeu_ps(op1 + 0, acc10); + _mm512_storeu_ps(op0 + 16, acc01); + _mm512_storeu_ps(op1 + 16, acc11); } for (; i < m; i++) { @@ -1822,7 +1819,7 @@ static void pv_gemm_bf16s_avx512bf16_kernel(float* O, const float* P, const unsi template static void qk_gemm_bf16s_avx512bf16_kernel_t(float* S, const float* Q, const unsigned short* K, - int m, int n, float scale) + int m, int n, float scale) { unsigned short* q_bf16 = (unsigned short*)_mm_malloc(m * D * sizeof(unsigned short), 64); @@ -2053,7 +2050,6 @@ static void qk_gemm_bf16s_avx512bf16_kernel_t(float* S, const float* Q, const un _mm_free(q_bf16); } - template static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n) { @@ -2137,8 +2133,10 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un acc10 = _mm512_dpbf16_ps(acc10, (__m512bh)p_ext1, (__m512bh)v0_ext); acc11 = _mm512_dpbf16_ps(acc11, (__m512bh)p_ext1, (__m512bh)v1_ext); } - _mm512_storeu_ps(op0 + 0, acc00); _mm512_storeu_ps(op1 + 0, acc10); - _mm512_storeu_ps(op0 + 16, acc01); _mm512_storeu_ps(op1 + 16, acc11); + _mm512_storeu_ps(op0 + 0, acc00); + _mm512_storeu_ps(op1 + 0, acc10); + _mm512_storeu_ps(op0 + 16, acc01); + _mm512_storeu_ps(op1 + 16, acc11); } for (; i < m; i++) { @@ -2241,7 +2239,4 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un #endif // __AVX512F__ && __AVX512BF16__ - - - #endif // SDPA_X86_BF16S_H From 155d7e47e3d24fda8a191d8fc64068f874e160cf Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 
Jun 2026 00:39:40 +0800 Subject: [PATCH 45/53] x86: improve sdpa prefill gqa large-dim path --- src/layer/x86/sdpa_x86.cpp | 2 +- tests/perf/perf_sdpa_prefill.cpp | 22 ++++++++++++++++++++++ tests/perf/perfutil.cpp | 18 +++++++++++++++++- 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 41ccf8783b5e..8185147a3ea2 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3197,7 +3197,7 @@ static int sdpa_forward_prefill( const int BLOCK_N = 128; Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - const bool large_dim = embed_dim > 512 && src_seqlen > 16; + const bool large_dim = embed_dim > 512 && (src_seqlen > 16 || num_heads_per_group > 1); Mat q_batch(embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); if (s_vec.empty() || o_accum.empty() || q_batch.empty()) diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp index 1977fbb9c39f..fae0eee19b53 100644 --- a/tests/perf/perf_sdpa_prefill.cpp +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -3,9 +3,31 @@ #include "perfutil.h" +#include + +static bool match_env_int(const char* name, int value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + + return atoi(s) == value; +} + +static bool should_run_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) +{ + return match_env_int("NCNN_PERF_SDPA_EMBED", embed_dim) + && match_env_int("NCNN_PERF_SDPA_HEADS", num_heads) + && match_env_int("NCNN_PERF_SDPA_GROUPS", num_groups) + && match_env_int("NCNN_PERF_SDPA_SEQLEN", src_seqlen); +} + // prefill phase: larger src_seqlen, no kv_cache (past_seqlen=0) static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) { + if (!should_run_prefill(embed_dim, num_heads, 
num_groups, src_seqlen)) + return; + const int cur_seqlen = src_seqlen; // in prefill, cur_seqlen == src_seqlen const int out_embed_dim = embed_dim; diff --git a/tests/perf/perfutil.cpp b/tests/perf/perfutil.cpp index e381b0a512de..79c81b7d9841 100644 --- a/tests/perf/perfutil.cpp +++ b/tests/perf/perfutil.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #if NCNN_VULKAN @@ -743,6 +744,15 @@ static const PrecisionConfig s_configs[] = { }; static const int s_num_configs = sizeof(s_configs) / sizeof(s_configs[0]); +static bool should_run_dtype(const char* label) +{ + const char* s = getenv("NCNN_PERF_DTYPE"); + if (!s || !s[0]) + return true; + + return strcmp(s, label) == 0; +} + static void perf_layer_impl(const char* layer_type, const ncnn::ParamDict& pd, const std::vector& weights, const std::vector& inputs, @@ -754,6 +764,9 @@ static void perf_layer_impl(const char* layer_type, const ncnn::ParamDict& pd, int cpu_inner_loops = 0; for (int i = 0; i < s_num_configs; i++) { + if (!should_run_dtype(s_configs[i].label)) + continue; + ncnn::Option opt = make_perf_option(s_configs[i].fp16_ps, s_configs[i].fp16_arith, s_configs[i].bf16); PerfResult result; @@ -792,6 +805,9 @@ static void perf_layer_impl(const char* layer_type, const ncnn::ParamDict& pd, int gpu_inner_loops = 0; for (int i = 0; i < s_num_configs; i++) { + if (!should_run_dtype(s_configs[i].label)) + continue; + ncnn::Option opt = make_perf_option(s_configs[i].fp16_ps, s_configs[i].fp16_arith, s_configs[i].bf16); PerfResult result; @@ -870,7 +886,7 @@ void perf_layer_int8(const char* layer_type, const ncnn::ParamDict& pd, opt.use_int8_inference = true; PerfResult result; - int ret = perf_layer_cpu(layer_type, pd, weights, inputs, top_blob_count, opt, 0, result); + int ret = should_run_dtype("int8") ? 
perf_layer_cpu(layer_type, pd, weights, inputs, top_blob_count, opt, 0, result) : -1; if (ret == 0) { char full_tag[512]; From 4b1d090cb0b9bbc2d2c8ddf93165f4b7910a6611 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 00:50:04 +0800 Subject: [PATCH 46/53] x86: avoid sdpa prefill bf16 query roundtrip --- src/layer/x86/sdpa_x86.cpp | 41 +++-- src/layer/x86/sdpa_x86_avx512bf16.cpp | 5 + src/layer/x86/sdpa_x86_bf16s.h | 214 ++++++++++++++++++++++++++ tests/test_sdpa.cpp | 17 +- 4 files changed, 262 insertions(+), 15 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 8185147a3ea2..cded1579eafc 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3174,6 +3174,7 @@ static inline void pv_gemm_sse2(float* O, const float* P, const float* V, int m, static int sdpa_forward_prefill( const Mat& query_ref, + const Mat& query_bf16_ref, const Mat& attn_mask_ref, Mat& key, Mat& value, @@ -3193,6 +3194,7 @@ static int sdpa_forward_prefill( bool use_bf16_path) { (void)num_heads; + (void)query_bf16_ref; const int BLOCK_M = 64; const int BLOCK_N = 128; Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); @@ -3714,10 +3716,11 @@ static int sdpa_forward_prefill( #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx512_bf16()) { - qk_gemm_bf16s_avx512bf16(s_head, - q_ptr, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); + const Mat query_bf16_head = query_bf16_ref.channel(q); + qk_gemm_bf16s_avx512bf16_qbf16(s_head, + query_bf16_head.row(m_start), + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); } else #endif @@ -4121,10 +4124,11 @@ static int sdpa_forward_prefill( #if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ if (ncnn::cpu_support_x86_avx512_bf16()) { - qk_gemm_bf16s_avx512bf16(s_head, - q_ptr, - key_head.row(n_start2), - block_m, block_n2, embed_dim, _scale); + const Mat 
query_bf16_head = query_bf16_ref.channel(q); + qk_gemm_bf16s_avx512bf16_qbf16(s_head, + query_bf16_head.row(m_start), + key_head.row(n_start2), + block_m, block_n2, embed_dim, _scale); } else #endif @@ -5724,13 +5728,24 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to #if NCNN_BF16 bool use_bf16_path = opt.use_bf16_storage && query.elembits() == 16; + bool use_qbf16_prefill = false; +#if __AVX512F__ +#if NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + use_qbf16_prefill = use_bf16_path && src_seqlen > 1 && embed_dim > 512 + && (src_seqlen > 16 || num_heads_per_group > 1) + && ncnn::cpu_support_x86_avx512_bf16(); +#endif +#endif Mat query_fp32; Mat attn_mask_fp32; if (use_bf16_path) { - cast_bfloat16_to_float32(query, query_fp32, opt); - if (query_fp32.empty()) - return -100; + if (!use_qbf16_prefill) + { + cast_bfloat16_to_float32(query, query_fp32, opt); + if (query_fp32.empty()) + return -100; + } if (attn_mask && !attn_mask_blob.empty()) { cast_bfloat16_to_float32(attn_mask_blob, attn_mask_fp32, opt); @@ -5738,7 +5753,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return -100; } } - const Mat& query_ref = use_bf16_path ? query_fp32 : query; + const Mat& query_ref = use_bf16_path && !use_qbf16_prefill ? query_fp32 : query; const Mat& attn_mask_ref = (use_bf16_path && attn_mask) ? 
attn_mask_fp32 : attn_mask_blob; #else const Mat& query_ref = query; @@ -5902,7 +5917,7 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } - return sdpa_forward_prefill(query_ref, attn_mask_ref, key, value, top_blob, + return sdpa_forward_prefill(query_ref, query, attn_mask_ref, key, value, top_blob, top_blobs, opt, embed_dim, src_seqlen, num_heads, num_group, out_embed_dim, dst_seqlen, num_heads_per_group, _scale, kv_cache, attn_mask, diff --git a/src/layer/x86/sdpa_x86_avx512bf16.cpp b/src/layer/x86/sdpa_x86_avx512bf16.cpp index 298d68bcef07..6befdf52c33b 100644 --- a/src/layer/x86/sdpa_x86_avx512bf16.cpp +++ b/src/layer/x86/sdpa_x86_avx512bf16.cpp @@ -72,6 +72,11 @@ void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, } } +void qk_gemm_bf16s_avx512bf16_qbf16(float* S, const unsigned short* Q, const unsigned short* K, int m, int n, int d, float scale) +{ + qk_gemm_bf16s_avx512bf16_qbf16_kernel(S, Q, K, m, n, d, scale); +} + void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d) { switch (d) diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h index 231b9f73bbe0..f08f2e1cdfd0 100644 --- a/src/layer/x86/sdpa_x86_bf16s.h +++ b/src/layer/x86/sdpa_x86_bf16s.h @@ -22,6 +22,7 @@ void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned short* K, int n_start, int block_n, int d, float scale); void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d); void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale); +void qk_gemm_bf16s_avx512bf16_qbf16(float* S, const unsigned short* Q, const unsigned short* K, int m, int n, int d, float scale); void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d); #endif @@ -1631,6 +1632,219 @@ static void 
qk_gemm_bf16s_avx512bf16_kernel(float* S, const float* Q, const unsi _mm_free(q_bf16); } +static void qk_gemm_bf16s_avx512bf16_qbf16_kernel(float* S, const unsigned short* Q, const unsigned short* K, + int m, int n, int d, float scale) +{ + int i = 0; + for (; i + 8 <= m; i += 8) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[8][2]; + for (int mi = 0; mi < 8; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) + { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[8][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 8; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = 
(_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[8]; + for (int mi = 0; mi < 8; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 8; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[8] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 8; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 8; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + for (; i + 4 <= m; i += 4) + { + int j = 0; + for (; j + 2 <= n; j += 2) + { + const unsigned short* k0 = K + (j + 0) * d; + const unsigned short* k1 = K + (j + 1) * d; + + __m512 acc[4][2]; + for (int mi = 0; mi < 4; mi++) + { + acc[mi][0] = _mm512_setzero_ps(); + acc[mi][1] = _mm512_setzero_ps(); + } + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv0 = _mm512_loadu_si512((const __m512i*)(k0 + k)); + __m512i kv1 = _mm512_loadu_si512((const __m512i*)(k1 + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + } + if (k + 15 < d) 
+ { + __m512i kv0 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k0 + k)), 0); + __m512i kv1 = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(k1 + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi][0] = _mm512_dpbf16_ps(acc[mi][0], (__m512bh)qv, (__m512bh)kv0); + acc[mi][1] = _mm512_dpbf16_ps(acc[mi][1], (__m512bh)qv, (__m512bh)kv1); + } + k += 16; + } + float tail_sum[4][2] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi][0] += qv * bfloat16_to_float32(k0[k]); + tail_sum[mi][1] += qv * bfloat16_to_float32(k1[k]); + } + } + + for (int mi = 0; mi < 4; mi++) + { + S[(i + mi) * n + j + 0] = (_mm512_comp_reduce_add_ps(acc[mi][0]) + tail_sum[mi][0]) * scale; + S[(i + mi) * n + j + 1] = (_mm512_comp_reduce_add_ps(acc[mi][1]) + tail_sum[mi][1]) * scale; + } + } + + for (; j < n; j++) + { + const unsigned short* kptr = K + j * d; + __m512 acc[4]; + for (int mi = 0; mi < 4; mi++) + acc[mi] = _mm512_setzero_ps(); + + int k = 0; + for (; k + 31 < d; k += 32) + { + __m512i kv = _mm512_loadu_si512((const __m512i*)(kptr + k)); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_loadu_si512((const __m512i*)(Q + (i + mi) * d + k)); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + } + if (k + 15 < d) + { + __m512i kv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(kptr + k)), 0); + for (int mi = 0; mi < 4; mi++) + { + __m512i qv = _mm512_inserti64x4(_mm512_setzero_si512(), _mm256_loadu_si256((const __m256i*)(Q + (i + mi) * d + k)), 0); + acc[mi] = _mm512_dpbf16_ps(acc[mi], (__m512bh)qv, (__m512bh)kv); + } + k += 16; + } + float tail_sum[4] = {}; + for (; k < d; k++) + { + for (int mi = 0; mi < 4; mi++) + { + float qv = 
bfloat16_to_float32(Q[(i + mi) * d + k]); + tail_sum[mi] += qv * bfloat16_to_float32(kptr[k]); + } + } + for (int mi = 0; mi < 4; mi++) + S[(i + mi) * n + j] = (_mm512_comp_reduce_add_ps(acc[mi]) + tail_sum[mi]) * scale; + } + } + + if (i < m) + { + for (int ii = i; ii < m; ii++) + for (int j = 0; j < n; j++) + { + float sum = 0.f; + for (int k = 0; k < d; k++) + sum += bfloat16_to_float32(Q[ii * d + k]) * bfloat16_to_float32(K[j * d + k]); + S[ii * n + j] = sum * scale; + } + } +} + static void pv_gemm_bf16s_avx512bf16_kernel(float* O, const float* P, const unsigned short* V, int m, int n, int d) { unsigned short* p_bf16 = (unsigned short*)_mm_malloc(m * n * sizeof(unsigned short), 64); diff --git a/tests/test_sdpa.cpp b/tests/test_sdpa.cpp index f25633b435ab..c043c6b731bd 100644 --- a/tests/test_sdpa.cpp +++ b/tests/test_sdpa.cpp @@ -3,6 +3,8 @@ #include "testutil.h" +#include + static int test_sdpa(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0.f) { const int src_seqlen = q.h; @@ -56,6 +58,17 @@ static int test_sdpa_0() || test_sdpa(RandomMat(12, 1, 4), RandomMat(12, 127, 4), RandomMat(55, 127, 4), 1, 1.f); } +static int test_sdpa_large_dim() +{ + if (!getenv("NCNN_TEST_SDPA_LARGE_DIM")) + return 0; + + return 0 + || test_sdpa(RandomMat(4096, 16, 32), RandomMat(4096, 16, 1), RandomMat(4096, 16, 1), 0, 1.f / 64.f) + || test_sdpa(RandomMat(4096, 16, 32), RandomMat(4096, 16, 4), RandomMat(4096, 16, 4), 0, 1.f / 64.f) + || test_sdpa(RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), 0, 1.f / 64.f); +} + #if NCNN_INT8 static int test_sdpa_int8(const ncnn::Mat& q, const ncnn::Mat& k, const ncnn::Mat& v, int attn_mask, float scale = 0.f) { @@ -111,8 +124,8 @@ int main() SRAND(7767517); #if NCNN_INT8 - return test_sdpa_0() || test_sdpa_1(); + return test_sdpa_0() || test_sdpa_1() || test_sdpa_large_dim(); #else - return test_sdpa_0(); + return test_sdpa_0() || test_sdpa_large_dim(); #endif } From 
7818da31e7dce46a549f41c0d73211ec7eb6ea8c Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 01:00:07 +0800 Subject: [PATCH 47/53] x86: avoid unused sdpa prefill q batch workspace --- src/layer/x86/sdpa_x86.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index cded1579eafc..2e10583acf05 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3200,9 +3200,11 @@ static int sdpa_forward_prefill( Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); const bool large_dim = embed_dim > 512 && (src_seqlen > 16 || num_heads_per_group > 1); - Mat q_batch(embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); + Mat q_batch; + if (!large_dim) + q_batch.create(embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - if (s_vec.empty() || o_accum.empty() || q_batch.empty()) + if (s_vec.empty() || o_accum.empty() || (!large_dim && q_batch.empty())) return -100; int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; From c375f8f0063c5914fcf65bc9a2b89456e45157fa Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 01:30:15 +0800 Subject: [PATCH 48/53] x86: hoist sdpa avx512 qk tail mask --- src/layer/x86/sdpa_x86.cpp | 52 ++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 28 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 2e10583acf05..6a22bd08c8b1 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -875,6 +875,8 @@ static inline void sdpa_int8_decode_core( static void qk_gemm_avx512(float* S, const float* Q, const float* K, int m, int n, int d, float scale) { + const int tail = d & 15; + const __mmask16 tail_mask = 
(__mmask16)((1u << tail) - 1); int i = 0; for (; i + 8 <= m; i += 8) { @@ -905,15 +907,14 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, } } - if (k < d) + if (tail) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - __m512 kv0 = _mm512_maskz_loadu_ps(mask, k0 + k); - __m512 kv1 = _mm512_maskz_loadu_ps(mask, k1 + k); + __m512 kv0 = _mm512_maskz_loadu_ps(tail_mask, k0 + k); + __m512 kv1 = _mm512_maskz_loadu_ps(tail_mask, k1 + k); for (int mi = 0; mi < 8; mi++) { - __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } @@ -945,13 +946,12 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, } } - if (k < d) + if (tail) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - __m512 kvec = _mm512_maskz_loadu_ps(mask, kptr + k); + __m512 kvec = _mm512_maskz_loadu_ps(tail_mask, kptr + k); for (int mi = 0; mi < 8; mi++) { - __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } } @@ -990,15 +990,14 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, } } - if (k < d) + if (tail) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - __m512 kv0 = _mm512_maskz_loadu_ps(mask, k0 + k); - __m512 kv1 = _mm512_maskz_loadu_ps(mask, k1 + k); + __m512 kv0 = _mm512_maskz_loadu_ps(tail_mask, k0 + k); + __m512 kv1 = _mm512_maskz_loadu_ps(tail_mask, k1 + k); for (int mi = 0; mi < 4; mi++) { - __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); acc[mi][0] = _mm512_fmadd_ps(qvec, kv0, acc[mi][0]); acc[mi][1] = _mm512_fmadd_ps(qvec, kv1, acc[mi][1]); } @@ -1030,13 +1029,12 @@ static void qk_gemm_avx512(float* S, const 
float* Q, const float* K, } } - if (k < d) + if (tail) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - __m512 kvec = _mm512_maskz_loadu_ps(mask, kptr + k); + __m512 kvec = _mm512_maskz_loadu_ps(tail_mask, kptr + k); for (int mi = 0; mi < 4; mi++) { - __m512 qvec = _mm512_maskz_loadu_ps(mask, Q + (i + mi) * d + k); + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, Q + (i + mi) * d + k); acc[mi] = _mm512_fmadd_ps(qvec, kvec, acc[mi]); } } @@ -1072,14 +1070,13 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, acc3 = _mm512_fmadd_ps(qvec, _mm512_loadu_ps(k3 + k), acc3); } - if (k < d) + if (tail) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - __m512 qvec = _mm512_maskz_loadu_ps(mask, qptr + k); - acc0 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k0 + k), acc0); - acc1 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k1 + k), acc1); - acc2 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k2 + k), acc2); - acc3 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(mask, k3 + k), acc3); + __m512 qvec = _mm512_maskz_loadu_ps(tail_mask, qptr + k); + acc0 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k0 + k), acc0); + acc1 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k1 + k), acc1); + acc2 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k2 + k), acc2); + acc3 = _mm512_fmadd_ps(qvec, _mm512_maskz_loadu_ps(tail_mask, k3 + k), acc3); } S[i * n + j + 0] = _mm512_comp_reduce_add_ps(acc0) * scale; @@ -1097,10 +1094,9 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, __m512 vacc = _mm512_setzero_ps(); for (; k + 15 < d; k += 16) vacc = _mm512_fmadd_ps(_mm512_loadu_ps(qptr + k), _mm512_loadu_ps(kptr + k), vacc); - if (k < d) + if (tail) { - __mmask16 mask = (__mmask16)((1u << (d - k)) - 1); - vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(mask, qptr + k), _mm512_maskz_loadu_ps(mask, kptr + k), vacc); + vacc = _mm512_fmadd_ps(_mm512_maskz_loadu_ps(tail_mask, qptr + k), 
_mm512_maskz_loadu_ps(tail_mask, kptr + k), vacc); } S[i * n + j] = _mm512_comp_reduce_add_ps(vacc) * scale; } From 7d5cfae895091b1110e1dbd8a0bda13617594356 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 01:38:31 +0800 Subject: [PATCH 49/53] x86: hoist sdpa avx512 pv prefetch check --- src/layer/x86/sdpa_x86.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 6a22bd08c8b1..96a881056932 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -1600,6 +1600,7 @@ template static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 16; + const bool prefetch_v = d >= 512 && n >= 256; int dd = 0; for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) { @@ -1621,7 +1622,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int for (int j = 0; j < n; j++) { - if (d >= 512 && n >= 256 && j + 4 < n) + if (prefetch_v && j + 4 < n) _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); __m512 vvec[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) @@ -1651,7 +1652,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int for (int j = 0; j < n; j++) { - if (d >= 512 && n >= 256 && j + 4 < n) + if (prefetch_v && j + 4 < n) _mm_prefetch((const char*)(V + (j + 4) * d + dd), _MM_HINT_T1); __m512 pvec = _mm512_set1_ps(pptr[j]); for (int vi = 0; vi < VEC_PER_UNROLL; vi++) From 3ebc91622856b94c769af9171d929e031c861c72 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 01:47:34 +0800 Subject: [PATCH 50/53] x86: avoid sdpa prefill mha q packing --- src/layer/x86/sdpa_x86.cpp | 123 ++++++++++++++++++++----------------- 1 file changed, 68 insertions(+), 55 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 96a881056932..f865ad9bd6de 100644 --- 
a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -3197,11 +3197,12 @@ static int sdpa_forward_prefill( Mat s_vec(BLOCK_N * BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); Mat o_accum(out_embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); const bool large_dim = embed_dim > 512 && (src_seqlen > 16 || num_heads_per_group > 1); + const bool pack_q = !large_dim && num_heads_per_group > 1; Mat q_batch; - if (!large_dim) + if (pack_q) q_batch.create(embed_dim, BLOCK_M * num_heads_per_group, opt.num_threads, 4u, opt.workspace_allocator); - if (s_vec.empty() || o_accum.empty() || (!large_dim && q_batch.empty())) + if (s_vec.empty() || o_accum.empty() || (pack_q && q_batch.empty())) return -100; int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; @@ -3226,7 +3227,9 @@ static int sdpa_forward_prefill( Mat s_vec_thread = s_vec.channel(get_omp_thread_num()); Mat o_accum_thread = o_accum.channel(get_omp_thread_num()); - Mat q_batch_thread = q_batch.channel(get_omp_thread_num()); + Mat q_batch_thread; + if (pack_q) + q_batch_thread = q_batch.channel(get_omp_thread_num()); // Pre-resolve mask pointers for all heads in this group const float* mask_data[num_heads_per_group]; @@ -3249,7 +3252,15 @@ static int sdpa_forward_prefill( Mat m_state_tile = m_state.channel(idx); Mat l_state_tile = l_state.channel(idx); - if (!large_dim) + Mat query_head_unpacked; + const float* q_data = 0; + if (!large_dim && !pack_q) + { + query_head_unpacked = query_ref.channel(g * num_heads_per_group); + q_data = query_head_unpacked.row(m_start); + } + + if (pack_q) { for (int hq = 0; hq < num_heads_per_group; hq++) { @@ -3288,6 +3299,8 @@ static int sdpa_forward_prefill( if (!large_dim) { + if (pack_q) + q_data = q_batch_thread.row(0); #if NCNN_BF16 if (use_bf16_path) { @@ -3296,7 +3309,7 @@ static int sdpa_forward_prefill( if (ncnn::cpu_support_x86_avx512_bf16()) { qk_gemm_bf16s_avx512bf16(s_ptr, - 
q_batch_thread.row(0), + q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); } @@ -3304,23 +3317,23 @@ static int sdpa_forward_prefill( #endif { qk_gemm_bf16s_avx512_kernel(s_ptr, - q_batch_thread.row(0), + q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); } #elif __AVX__ qk_gemm_bf16s_avx_kernel(s_ptr, - q_batch_thread.row(0), + q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #elif __SSE2__ qk_gemm_bf16s_sse2_kernel(s_ptr, - q_batch_thread.row(0), + q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #else qk_gemm_bf16s_scalar_kernel(s_ptr, - q_batch_thread.row(0), + q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #endif @@ -3332,167 +3345,167 @@ static int sdpa_forward_prefill( { case 128: #if __AVX512F__ - qk_gemm_specialized_avx512<128>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<128>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<128>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<128>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<128>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<128>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 64: #if __AVX512F__ - qk_gemm_specialized_avx512<64>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<64>(s_ptr, q_data, key_head.row(n_start), block_m * 
num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<64>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<64>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<64>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<64>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 512: #if __AVX512F__ - qk_gemm_specialized_avx512<512>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<512>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<512>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<512>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<512>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<512>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 256: #if __AVX512F__ - qk_gemm_specialized_avx512<256>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<256>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<256>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<256>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); 
break; #elif __SSE2__ - qk_gemm_specialized_sse2<256>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<256>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 32: #if __AVX512F__ - qk_gemm_specialized_avx512<32>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<32>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<32>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<32>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<32>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<32>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 80: #if __AVX512F__ - qk_gemm_specialized_avx512<80>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<80>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<80>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<80>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<80>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<80>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 96: #if __AVX512F__ - 
qk_gemm_specialized_avx512<96>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<96>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<96>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<96>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<96>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<96>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 160: #if __AVX512F__ - qk_gemm_specialized_avx512<160>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<160>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<160>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<160>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<160>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<160>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 1024: #if __AVX512F__ - qk_gemm_specialized_avx512<1024>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<1024>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<1024>(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<1024>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<1024>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<1024>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 2048: #if __AVX512F__ - qk_gemm_specialized_avx512<2048>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<2048>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<2048>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<2048>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<2048>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<2048>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 4096: #if __AVX512F__ - qk_gemm_specialized_avx512<4096>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<4096>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<4096>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<4096>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<4096>(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<4096>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 768: #if __AVX512F__ - qk_gemm_specialized_avx512<768>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<768>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<768>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<768>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<768>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<768>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 1536: #if __AVX512F__ - qk_gemm_specialized_avx512<1536>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<1536>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<1536>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<1536>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<1536>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<1536>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif case 3072: #if __AVX512F__ - qk_gemm_specialized_avx512<3072>(s_ptr, 
q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx512<3072>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __AVX__ - qk_gemm_specialized_avx<3072>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_avx<3072>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #elif __SSE2__ - qk_gemm_specialized_sse2<3072>(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); + qk_gemm_specialized_sse2<3072>(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, _scale); break; #endif default: #if __AVX512F__ - qk_gemm_avx512(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); + qk_gemm_avx512(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #elif __AVX__ - qk_gemm_avx(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); + qk_gemm_avx(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #elif __SSE2__ - qk_gemm_sse2(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); + qk_gemm_sse2(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #else - qk_gemm_scalar(s_ptr, q_batch_thread.row(0), key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); + qk_gemm_scalar(s_ptr, q_data, key_head.row(n_start), block_m * num_heads_per_group, block_n, embed_dim, _scale); #endif break; } From 99ec4dba6f520af9bfa0e5d1bcecc854d5249895 Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 01:54:37 +0800 Subject: [PATCH 51/53] x86: 
avoid sdpa prefill softmax max copy --- src/layer/x86/sdpa_x86.cpp | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index f865ad9bd6de..83e63643ccb3 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -135,7 +135,7 @@ static inline void vec_zero(float* x, int n) } static inline void softmax_tile(float* P, const float* S, - float* m_vec, float* l_vec, float* scale_out, int m, int n) + float* m_vec, float* l_vec, float* scale_out, unsigned char* changed_out, int m, int n) { for (int i = 0; i < m; i++) { @@ -157,6 +157,7 @@ static inline void softmax_tile(float* P, const float* S, float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != m_new; l_vec[i] *= scale_factor; __m512 vm_new = _mm512_set1_ps(m_new); @@ -191,6 +192,7 @@ static inline void softmax_tile(float* P, const float* S, float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != m_new; l_vec[i] *= scale_factor; __m256 vm_new = _mm256_set1_ps(m_new); @@ -222,6 +224,7 @@ static inline void softmax_tile(float* P, const float* S, float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != m_new; l_vec[i] *= scale_factor; __m128 vm_new = _mm_set1_ps(m_new); @@ -248,6 +251,7 @@ static inline void softmax_tile(float* P, const float* S, m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; + changed_out[i] = m_vec[i] != m_new; l_vec[i] *= scale_factor; float l_add = 0.f; for (int j = 0; j < n; j++) @@ -3529,17 +3533,13 @@ static int sdpa_forward_prefill( float* l_vec = l_state_tile.row(hq); float* o_ptr = o_accum_thread.row(hq * block_m); - float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile(s_head, s_head, 
m_vec, l_vec, scale_factors, block_m, block_n); + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n); for (int i = 0; i < block_m; i++) { - if (m_old[i] != m_vec[i]) + if (changed[i]) { vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } @@ -3946,17 +3946,13 @@ static int sdpa_forward_prefill( float* l_vec = l_state_tile.row(hq); float* o_ptr = o_accum_thread.row(hq * block_m); - float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); for (int i = 0; i < block_m; i++) { - if (m_old[i] != m_vec[i]) + if (changed[i]) { vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } @@ -4354,17 +4350,13 @@ static int sdpa_forward_prefill( float* l_vec = l_state_tile.row(hq); float* o_ptr = o_accum_thread.row(hq * block_m); - float m_old[BLOCK_M]; float scale_factors[BLOCK_M]; - for (int i = 0; i < block_m; i++) - { - m_old[i] = m_vec[i]; - } - softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, block_m, block_n2); + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); for (int i = 0; i < block_m; i++) { - if (m_old[i] != m_vec[i]) + if (changed[i]) { vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); } From 4cf724f886873d3cc4738b92285c6be974aaa34c Mon Sep 17 00:00:00 2001 From: futz12 <1391525377@qq.com> Date: Wed, 3 Jun 2026 09:54:04 +0800 Subject: [PATCH 52/53] optimize x86 sdpa prefill kernels --- src/layer/x86/sdpa_x86.cpp | 800 +++++++++++++++++++++----- src/layer/x86/sdpa_x86_avx512bf16.cpp | 26 +- src/layer/x86/sdpa_x86_bf16s.h | 36 +- tests/perf/perf_sdpa_prefill.cpp | 76 ++- tests/perf/perfutil.cpp | 113 +++- 
tests/test_sdpa.cpp | 13 +- 6 files changed, 873 insertions(+), 191 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 83e63643ccb3..2d67fbd7c210 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -157,7 +157,7 @@ static inline void softmax_tile(float* P, const float* S, float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; - changed_out[i] = m_vec[i] != m_new; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; l_vec[i] *= scale_factor; __m512 vm_new = _mm512_set1_ps(m_new); @@ -192,7 +192,7 @@ static inline void softmax_tile(float* P, const float* S, float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; - changed_out[i] = m_vec[i] != m_new; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; l_vec[i] *= scale_factor; __m256 vm_new = _mm256_set1_ps(m_new); @@ -224,7 +224,7 @@ static inline void softmax_tile(float* P, const float* S, float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; - changed_out[i] = m_vec[i] != m_new; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; l_vec[i] *= scale_factor; __m128 vm_new = _mm_set1_ps(m_new); @@ -251,7 +251,7 @@ static inline void softmax_tile(float* P, const float* S, m_new = std::max(m_new, sptr[j]); float scale_factor = expf(m_vec[i] - m_new); scale_out[i] = scale_factor; - changed_out[i] = m_vec[i] != m_new; + changed_out[i] = m_vec[i] != -FLT_MAX && m_vec[i] != m_new; l_vec[i] *= scale_factor; float l_add = 0.f; for (int j = 0; j < n; j++) @@ -1093,7 +1093,6 @@ static void qk_gemm_avx512(float* S, const float* Q, const float* K, { const float* qptr = Q + i * d; const float* kptr = K + j * d; - float sum = 0.f; int k = 0; __m512 vacc = _mm512_setzero_ps(); for (; k + 15 < d; k += 16) @@ -1282,7 +1281,7 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co int j = 0; for (; j + 4 <= n; j += 4) { - if (D >= 512 && n >= 256 && j + 8 <= n) + 
if (D >= 512 && n >= 128 && j + 8 <= n) { _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); @@ -1390,7 +1389,7 @@ static inline void qk_gemm_specialized_tiled_avx512(float* S, const float* Q, co int j = 0; for (; j + 4 <= n; j += 4) { - if (D >= 512 && n >= 256 && j + 8 <= n) + if (D >= 512 && n >= 128 && j + 8 <= n) { _mm_prefetch((const char*)(K + (j + 4) * D), _MM_HINT_T1); _mm_prefetch((const char*)(K + (j + 5) * D), _MM_HINT_T1); @@ -1597,10 +1596,10 @@ template<> void qk_gemm_specialized_avx512_large_m<4096>(float* S, const float* Q, const float* K, int m, int n, float scale) { - qk_gemm_specialized_tiled_avx512<4096, 2, 2>(S, Q, K, m, n, scale); + qk_gemm_specialized_tiled_avx512<4096, 4, 6>(S, Q, K, m, n, scale); } -template +template static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n, int d) { const int VEC_PER_UNROLL = D_UNROLL / 16; @@ -1622,7 +1621,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int __m512 acc[M_BLOCK][VEC_PER_UNROLL]; for (int mi = 0; mi < M_BLOCK; mi++) for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[mi][vi] = _mm512_loadu_ps(op[mi] + vi * 16); + acc[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op[mi] + vi * 16); for (int j = 0; j < n; j++) { @@ -1652,7 +1651,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int __m512 acc[VEC_PER_UNROLL]; for (int vi = 0; vi < VEC_PER_UNROLL; vi++) - acc[vi] = _mm512_loadu_ps(optr + vi * 16); + acc[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr + vi * 16); for (int j = 0; j < n; j++) { @@ -1683,7 +1682,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int __m512 acc[M_BLOCK]; for (int mi = 0; mi < M_BLOCK; mi++) - acc[mi] = _mm512_loadu_ps(op[mi]); + acc[mi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op[mi]); for (int j = 0; j < n; j++) { @@ -1700,7 +1699,7 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int { float* optr = O + i * d + dd; const float* pptr = P + i * n; - __m512 acc = _mm512_loadu_ps(optr); + __m512 acc = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr); for (int j = 0; j < n; j++) acc = _mm512_fmadd_ps(_mm512_set1_ps(pptr[j]), _mm512_loadu_ps(V + j * d + dd), acc); _mm512_storeu_ps(optr, acc); @@ -1719,6 +1718,11 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int op[mi] = O + (i + mi) * d + dd; pptr[mi] = P + (i + mi) * n; } + if (INIT_ZERO) + { + for (int mi = 0; mi < M_BLOCK; mi++) + op[mi][0] = 0.f; + } for (int j = 0; j < n; j++) { for (int mi = 0; mi < M_BLOCK; mi++) @@ -1729,16 +1733,245 @@ static inline void pv_gemm_avx512(float* O, const float* P, const float* V, int { float* optr = O + i * d + dd; const float* pptr = P + i * n; + if (INIT_ZERO) + optr[0] = 0.f; for (int j = 0; j < n; j++) optr[0] += pptr[j] * V[j * d + dd]; } } } +template +static inline void pv_gemm_2heads_4096_avx512(float* O0, float* O1, const float* P0, const float* P1, const float* V, int m, int n) +{ + const int d = 4096; + const int VEC_PER_UNROLL = D_UNROLL / 16; + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op0[M_BLOCK]; + float* op1[M_BLOCK]; + const float* pptr0[M_BLOCK]; + const float* pptr1[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op0[mi] = O0 + (i + mi) * d + dd; + op1[mi] = O1 + (i + mi) * d + dd; + pptr0[mi] = P0 + (i + mi) * n; + pptr1[mi] = P1 + (i + mi) * n; + } + + __m512 acc0[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc1[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op0[mi] + vi * 16); + acc1[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op1[mi] + vi * 16); + } + } + + for (int j = 0; j < n; j++) + { + __m512 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[mi][j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = _mm512_fmadd_ps(pvec0, vvec[vi], acc0[mi][vi]); + acc1[mi][vi] = _mm512_fmadd_ps(pvec1, vvec[vi], acc1[mi][vi]); + } + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(op0[mi] + vi * 16, acc0[mi][vi]); + _mm512_storeu_ps(op1[mi] + vi * 16, acc1[mi][vi]); + } + } + } + + for (; i < m; i++) + { + float* optr0 = O0 + i * d + dd; + float* optr1 = O1 + i * d + dd; + const float* pptr0 = P0 + i * n; + const float* pptr1 = P1 + i * n; + + __m512 acc0[VEC_PER_UNROLL]; + __m512 acc1[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr0 + vi * 16); + acc1[vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(optr1 + vi * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd + vi * 16); + acc0[vi] = _mm512_fmadd_ps(pvec0, vvec, acc0[vi]); + acc1[vi] = _mm512_fmadd_ps(pvec1, vvec, acc1[vi]); + } + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(optr0 + vi * 16, acc0[vi]); + _mm512_storeu_ps(optr1 + vi * 16, acc1[vi]); + } + } + } +} + +template +static inline void pv_gemm_4heads_4096_avx512(float* O0, float* O1, float* O2, float* O3, const float* P0, const float* P1, const float* P2, const float* P3, const float* V, int m, int n) +{ + const int d = 4096; + const int VEC_PER_UNROLL = D_UNROLL / 16; + int dd = 0; + for (; dd + D_UNROLL - 1 < d; dd += D_UNROLL) + { + int i = 0; + for (; i + M_BLOCK <= m; i += M_BLOCK) + { + float* op0[M_BLOCK]; + float* op1[M_BLOCK]; + float* op2[M_BLOCK]; + float* op3[M_BLOCK]; + const float* pptr0[M_BLOCK]; + const float* pptr1[M_BLOCK]; + const float* pptr2[M_BLOCK]; + const float* pptr3[M_BLOCK]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + op0[mi] = O0 + (i + mi) * d + dd; + op1[mi] = O1 + (i + mi) * d + dd; + op2[mi] = O2 + (i + mi) * d + dd; + op3[mi] = O3 + (i + mi) * d + dd; + pptr0[mi] = P0 + (i + mi) * n; + pptr1[mi] = P1 + (i + mi) * n; + pptr2[mi] = P2 + (i + mi) * n; + pptr3[mi] = P3 + (i + mi) * n; + } + + __m512 acc0[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc1[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc2[M_BLOCK][VEC_PER_UNROLL]; + __m512 acc3[M_BLOCK][VEC_PER_UNROLL]; + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op0[mi] + vi * 16); + acc1[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op1[mi] + vi * 16); + acc2[mi][vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op2[mi] + vi * 16); + acc3[mi][vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(op3[mi] + vi * 16); + } + } + + for (int j = 0; j < n; j++) + { + __m512 vvec[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + vvec[vi] = _mm512_loadu_ps(V + j * d + dd + vi * 16); + + for (int mi = 0; mi < M_BLOCK; mi++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[mi][j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[mi][j]); + __m512 pvec2 = _mm512_set1_ps(pptr2[mi][j]); + __m512 pvec3 = _mm512_set1_ps(pptr3[mi][j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[mi][vi] = _mm512_fmadd_ps(pvec0, vvec[vi], acc0[mi][vi]); + acc1[mi][vi] = _mm512_fmadd_ps(pvec1, vvec[vi], acc1[mi][vi]); + acc2[mi][vi] = _mm512_fmadd_ps(pvec2, vvec[vi], acc2[mi][vi]); + acc3[mi][vi] = _mm512_fmadd_ps(pvec3, vvec[vi], acc3[mi][vi]); + } + } + } + + for (int mi = 0; mi < M_BLOCK; mi++) + { + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(op0[mi] + vi * 16, acc0[mi][vi]); + _mm512_storeu_ps(op1[mi] + vi * 16, acc1[mi][vi]); + _mm512_storeu_ps(op2[mi] + vi * 16, acc2[mi][vi]); + _mm512_storeu_ps(op3[mi] + vi * 16, acc3[mi][vi]); + } + } + } + + for (; i < m; i++) + { + float* optr0 = O0 + i * d + dd; + float* optr1 = O1 + i * d + dd; + float* optr2 = O2 + i * d + dd; + float* optr3 = O3 + i * d + dd; + const float* pptr0 = P0 + i * n; + const float* pptr1 = P1 + i * n; + const float* pptr2 = P2 + i * n; + const float* pptr3 = P3 + i * n; + + __m512 acc0[VEC_PER_UNROLL]; + __m512 acc1[VEC_PER_UNROLL]; + __m512 acc2[VEC_PER_UNROLL]; + __m512 acc3[VEC_PER_UNROLL]; + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + acc0[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr0 + vi * 16); + acc1[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr1 + vi * 16); + acc2[vi] = INIT_ZERO ? _mm512_setzero_ps() : _mm512_loadu_ps(optr2 + vi * 16); + acc3[vi] = INIT_ZERO ? 
_mm512_setzero_ps() : _mm512_loadu_ps(optr3 + vi * 16); + } + + for (int j = 0; j < n; j++) + { + __m512 pvec0 = _mm512_set1_ps(pptr0[j]); + __m512 pvec1 = _mm512_set1_ps(pptr1[j]); + __m512 pvec2 = _mm512_set1_ps(pptr2[j]); + __m512 pvec3 = _mm512_set1_ps(pptr3[j]); + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + __m512 vvec = _mm512_loadu_ps(V + j * d + dd + vi * 16); + acc0[vi] = _mm512_fmadd_ps(pvec0, vvec, acc0[vi]); + acc1[vi] = _mm512_fmadd_ps(pvec1, vvec, acc1[vi]); + acc2[vi] = _mm512_fmadd_ps(pvec2, vvec, acc2[vi]); + acc3[vi] = _mm512_fmadd_ps(pvec3, vvec, acc3[vi]); + } + } + + for (int vi = 0; vi < VEC_PER_UNROLL; vi++) + { + _mm512_storeu_ps(optr0 + vi * 16, acc0[vi]); + _mm512_storeu_ps(optr1 + vi * 16, acc1[vi]); + _mm512_storeu_ps(optr2 + vi * 16, acc2[vi]); + _mm512_storeu_ps(optr3 + vi * 16, acc3[vi]); + } + } + } +} + template static void pv_gemm_avx512(float* O, const float* P, const float* V, int m, int n) { - const int VEC_PER_D = D / 16; int i = 0; for (; i + M_BLOCK <= m; i += M_BLOCK) { @@ -3209,6 +3442,17 @@ static int sdpa_forward_prefill( if (s_vec.empty() || o_accum.empty() || (pack_q && q_batch.empty())) return -100; +#if __AVX512F__ +#if NCNN_BF16 && NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ + const bool bf16_init_zero_path = use_bf16_path && out_embed_dim == 4096 && ncnn::cpu_support_x86_avx512_bf16(); +#else + const bool bf16_init_zero_path = false; +#endif + const bool skip_o_accum_zero = out_embed_dim == 4096 && (!use_bf16_path || bf16_init_zero_path); +#else + const bool skip_o_accum_zero = false; +#endif + int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; // Per-head per-M-tile softmax state for cross-N-tile accumulation @@ -3290,11 +3534,14 @@ static int sdpa_forward_prefill( } float* o_ptr = o_accum_thread.row(hq * block_m); - vec_zero(o_ptr, out_embed_dim * block_m); + if (!skip_o_accum_zero) + vec_zero(o_ptr, out_embed_dim * block_m); } - // N-outer loop: K/V N-tile is loaded once and reused by all Q 
heads in this group - for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + // N-outer loop: K/V N-tile is loaded once and reused by all Q heads in this group. + // The large-dim path below has its own N loop over all tiles. + const int n_loop_end = large_dim ? 1 : dst_seqlen; + for (int n_start = 0; n_start < n_loop_end; n_start += BLOCK_N) { int n_end = n_start + BLOCK_N < dst_seqlen ? n_start + BLOCK_N : dst_seqlen; int block_n = n_end - n_start; @@ -3554,7 +3801,7 @@ static int sdpa_forward_prefill( if (ncnn::cpu_support_x86_avx512_bf16()) { pv_gemm_bf16s_avx512bf16(o_accum_thread.row(0), s_ptr, value_head.row(n_start), - block_m * num_heads_per_group, block_n, out_embed_dim); + block_m * num_heads_per_group, block_n, out_embed_dim, n_start == 0 && out_embed_dim == 4096); } else #endif @@ -3646,7 +3893,10 @@ static int sdpa_forward_prefill( #endif case 4096: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + if (n_start == 0) + pv_gemm_avx512<2, 128, true>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); + else + pv_gemm_avx512<2, 128>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); break; #elif __AVX__ pv_gemm_avx<2, 64>(o_accum_thread.row(0), s_ptr, value_head.row(n_start), block_m * num_heads_per_group, block_n, 4096); @@ -3710,6 +3960,40 @@ static int sdpa_forward_prefill( int n_end2 = n_start2 + BLOCK_N < dst_seqlen ? 
n_start2 + BLOCK_N : dst_seqlen; int block_n2 = n_end2 - n_start2; +#if __AVX512F__ + if (!use_bf16_path && !attn_mask && num_heads_per_group == 1 && embed_dim == 4096 && out_embed_dim == 4096) + { + const Mat query_head = query_ref.channel(g); + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr; + + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec = m_state_tile.row(0); + float* l_vec = l_state_tile.row(0); + float* o_ptr = o_accum_thread.row(0); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * 4096, scale_factors[i], 4096); + } + } + + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + + continue; + } +#endif + if (num_group * num_m_tiles == 1) { #pragma omp parallel for num_threads(opt.num_threads) if (opt.num_threads > 1) @@ -3876,7 +4160,7 @@ static int sdpa_forward_prefill( #endif case 4096: #if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); @@ -3966,7 +4250,7 @@ static int sdpa_forward_prefill( if (ncnn::cpu_support_x86_avx512_bf16()) { pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); + block_m, block_n2, out_embed_dim, n_start2 == 0 && out_embed_dim == 4096); } else #endif @@ -4058,7 +4342,10 @@ static int sdpa_forward_prefill( #endif case 4096: #if __AVX512F__ - pv_gemm_avx512<4, 
64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); break; #elif __AVX__ pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); @@ -4117,6 +4404,179 @@ static int sdpa_forward_prefill( } else { +#if __AVX512F__ + if (!use_bf16_path && !attn_mask && embed_dim == 4096 && out_embed_dim == 4096 && num_heads_per_group >= 4) + { + int hq = 0; + for (; hq + 3 < num_heads_per_group; hq += 4) + { + int q0 = g * num_heads_per_group + hq; + int q1 = q0 + 1; + int q2 = q0 + 2; + int q3 = q0 + 3; + const Mat query_head0 = query_ref.channel(q0); + const Mat query_head1 = query_ref.channel(q1); + const Mat query_head2 = query_ref.channel(q2); + const Mat query_head3 = query_ref.channel(q3); + + const float* q_ptr0 = query_head0.row(m_start); + const float* q_ptr1 = query_head1.row(m_start); + const float* q_ptr2 = query_head2.row(m_start); + const float* q_ptr3 = query_head3.row(m_start); + float* s_head0 = s_ptr + hq * block_m * block_n2; + float* s_head1 = s_head0 + block_m * block_n2; + float* s_head2 = s_head1 + block_m * block_n2; + float* s_head3 = s_head2 + block_m * block_n2; + + qk_gemm_specialized_avx512_large_m<4096>(s_head0, q_ptr0, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head1, q_ptr1, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head2, q_ptr2, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head3, q_ptr3, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec0 = m_state_tile.row(hq); + float* l_vec0 = l_state_tile.row(hq); + float* o_ptr0 = o_accum_thread.row(hq * block_m); + float* m_vec1 = m_state_tile.row(hq + 1); + float* l_vec1 = 
l_state_tile.row(hq + 1); + float* o_ptr1 = o_accum_thread.row((hq + 1) * block_m); + float* m_vec2 = m_state_tile.row(hq + 2); + float* l_vec2 = l_state_tile.row(hq + 2); + float* o_ptr2 = o_accum_thread.row((hq + 2) * block_m); + float* m_vec3 = m_state_tile.row(hq + 3); + float* l_vec3 = l_state_tile.row(hq + 3); + float* o_ptr3 = o_accum_thread.row((hq + 3) * block_m); + + float scale_factors0[BLOCK_M]; + float scale_factors1[BLOCK_M]; + float scale_factors2[BLOCK_M]; + float scale_factors3[BLOCK_M]; + unsigned char changed0[BLOCK_M]; + unsigned char changed1[BLOCK_M]; + unsigned char changed2[BLOCK_M]; + unsigned char changed3[BLOCK_M]; + softmax_tile(s_head0, s_head0, m_vec0, l_vec0, scale_factors0, changed0, block_m, block_n2); + softmax_tile(s_head1, s_head1, m_vec1, l_vec1, scale_factors1, changed1, block_m, block_n2); + softmax_tile(s_head2, s_head2, m_vec2, l_vec2, scale_factors2, changed2, block_m, block_n2); + softmax_tile(s_head3, s_head3, m_vec3, l_vec3, scale_factors3, changed3, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed0[i]) + { + vec_scale(o_ptr0 + i * out_embed_dim, scale_factors0[i], out_embed_dim); + } + if (changed1[i]) + { + vec_scale(o_ptr1 + i * out_embed_dim, scale_factors1[i], out_embed_dim); + } + if (changed2[i]) + { + vec_scale(o_ptr2 + i * out_embed_dim, scale_factors2[i], out_embed_dim); + } + if (changed3[i]) + { + vec_scale(o_ptr3 + i * out_embed_dim, scale_factors3[i], out_embed_dim); + } + } + + if (block_n2 < 128) + { + if (n_start2 == 0) + pv_gemm_4heads_4096_avx512<2, 32, true>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + else + pv_gemm_4heads_4096_avx512<2, 32>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + } + else + { + if (n_start2 == 0) + pv_gemm_4heads_4096_avx512<4, 16, true>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, 
value_head.row(n_start2), block_m, block_n2); + else + pv_gemm_4heads_4096_avx512<4, 16>(o_ptr0, o_ptr1, o_ptr2, o_ptr3, s_head0, s_head1, s_head2, s_head3, value_head.row(n_start2), block_m, block_n2); + } + } + + for (; hq + 1 < num_heads_per_group; hq += 2) + { + int q0 = g * num_heads_per_group + hq; + int q1 = q0 + 1; + const Mat query_head0 = query_ref.channel(q0); + const Mat query_head1 = query_ref.channel(q1); + + const float* q_ptr0 = query_head0.row(m_start); + const float* q_ptr1 = query_head1.row(m_start); + float* s_head0 = s_ptr + hq * block_m * block_n2; + float* s_head1 = s_head0 + block_m * block_n2; + + qk_gemm_specialized_avx512_large_m<4096>(s_head0, q_ptr0, key_head.row(n_start2), block_m, block_n2, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head1, q_ptr1, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec0 = m_state_tile.row(hq); + float* l_vec0 = l_state_tile.row(hq); + float* o_ptr0 = o_accum_thread.row(hq * block_m); + float* m_vec1 = m_state_tile.row(hq + 1); + float* l_vec1 = l_state_tile.row(hq + 1); + float* o_ptr1 = o_accum_thread.row((hq + 1) * block_m); + + float scale_factors0[BLOCK_M]; + float scale_factors1[BLOCK_M]; + unsigned char changed0[BLOCK_M]; + unsigned char changed1[BLOCK_M]; + softmax_tile(s_head0, s_head0, m_vec0, l_vec0, scale_factors0, changed0, block_m, block_n2); + softmax_tile(s_head1, s_head1, m_vec1, l_vec1, scale_factors1, changed1, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed0[i]) + { + vec_scale(o_ptr0 + i * out_embed_dim, scale_factors0[i], out_embed_dim); + } + if (changed1[i]) + { + vec_scale(o_ptr1 + i * out_embed_dim, scale_factors1[i], out_embed_dim); + } + } + + if (n_start2 == 0) + pv_gemm_2heads_4096_avx512<4, 32, true>(o_ptr0, o_ptr1, s_head0, s_head1, value_head.row(n_start2), block_m, block_n2); + else + pv_gemm_2heads_4096_avx512<4, 32>(o_ptr0, o_ptr1, s_head0, s_head1, value_head.row(n_start2), block_m, block_n2); + } + + for (; hq < 
num_heads_per_group; hq++) + { + int q = g * num_heads_per_group + hq; + const Mat query_head = query_ref.channel(q); + const float* q_ptr = query_head.row(m_start); + float* s_head = s_ptr + hq * block_m * block_n2; + + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); + + float* m_vec = m_state_tile.row(hq); + float* l_vec = l_state_tile.row(hq); + float* o_ptr = o_accum_thread.row(hq * block_m); + + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(s_head, s_head, m_vec, l_vec, scale_factors, changed, block_m, block_n2); + + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + { + vec_scale(o_ptr + i * out_embed_dim, scale_factors[i], out_embed_dim); + } + } + + if (n_start2 == 0) + pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + } + + continue; + } +#endif + for (int hq = 0; hq < num_heads_per_group; hq++) { int q = g * num_heads_per_group + hq; @@ -4280,7 +4740,7 @@ static int sdpa_forward_prefill( #endif case 4096: #if __AVX512F__ - qk_gemm_avx512(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); + qk_gemm_specialized_avx512_large_m<4096>(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, _scale); break; #elif __AVX__ qk_gemm_avx(s_head, q_ptr, key_head.row(n_start2), block_m, block_n2, 4096, _scale); @@ -4370,7 +4830,7 @@ static int sdpa_forward_prefill( if (ncnn::cpu_support_x86_avx512_bf16()) { pv_gemm_bf16s_avx512bf16(o_ptr, s_head, value_head.row(n_start2), - block_m, block_n2, out_embed_dim); + block_m, block_n2, out_embed_dim, n_start2 == 0 && out_embed_dim == 4096); } else #endif @@ -4462,7 +4922,10 @@ static int sdpa_forward_prefill( #endif case 4096: #if __AVX512F__ - pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + if (n_start2 == 0) + 
pv_gemm_avx512<4, 64, true>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); break; #elif __AVX__ pv_gemm_avx<2, 64>(o_ptr, s_head, value_head.row(n_start2), block_m, block_n2, 4096); @@ -4798,6 +5261,158 @@ static int sdpa_prefill_int8_x86( { const int BLOCK_M = 64; const int BLOCK_N = 128; +#if __AVX512F__ + const bool skip_o_accum_zero = !kv_cache && out_embed_dim == 4096; +#else + const bool skip_o_accum_zero = false; +#endif + +#if __AVX512F__ + if (!kv_cache && !attn_mask && out_embed_dim == 4096 && num_heads_per_group >= 2) + { + const int num_group = num_heads / num_heads_per_group; + const int num_m_tiles = (src_seqlen + BLOCK_M - 1) / BLOCK_M; + const int num_head_pairs = (num_heads_per_group + 1) / 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int task = 0; task < num_group * num_m_tiles * num_head_pairs; task++) + { + int pair_idx = task % num_head_pairs; + int tmp = task / num_head_pairs; + int m_tile = tmp % num_m_tiles; + int g = tmp / num_m_tiles; + + int hq0 = pair_idx * 2; + int hq1 = hq0 + 1; + int q0 = g * num_heads_per_group + hq0; + int q1 = q0 + 1; + bool has_q1 = hq1 < num_heads_per_group; + + const Mat query_head0 = query.channel(q0); + const Mat query_head1 = has_q1 ? query.channel(q1) : Mat(); + const Mat key_int8_head = key_int8.channel(g); + const Mat key_scales_head = key_scales.channel(g); + const Mat value_head = value.channel(g); + Mat top_blob_head0 = top_blob.channel(q0); + Mat top_blob_head1 = has_q1 ? 
top_blob.channel(q1) : Mat(); + + Mat o_accum_head = o_accum.channel(get_omp_thread_num()); + float* s_vec_ptr0 = s_vec.row(get_omp_thread_num()); + float* s_vec_ptr1 = s_vec_ptr0 + BLOCK_M * BLOCK_N; + float* p_vec_ptr0 = p_vec.row(get_omp_thread_num()); + float* p_vec_ptr1 = p_vec_ptr0 + BLOCK_M * BLOCK_N; + Mat q_int8_tile_head = q_int8_tile.channel(get_omp_thread_num()); + Mat q_scales_tile_head = q_scales_tile.channel(get_omp_thread_num()); + + int m_start = m_tile * BLOCK_M; + int m_end = m_start + BLOCK_M < src_seqlen ? m_start + BLOCK_M : src_seqlen; + int block_m = m_end - m_start; + + for (int i = 0; i < block_m; i++) + { + dynamic_quantize_rowwise(query_head0.row(m_start + i), q_int8_tile_head.row(i), q_scales_tile_head.row(i), embed_dim); + q_scales_tile_head.row(i)[0] = 1.f / q_scales_tile_head.row(i)[0]; + + if (has_q1) + { + dynamic_quantize_rowwise(query_head1.row(m_start + i), q_int8_tile_head.row(BLOCK_M + i), q_scales_tile_head.row(BLOCK_M + i), embed_dim); + q_scales_tile_head.row(BLOCK_M + i)[0] = 1.f / q_scales_tile_head.row(BLOCK_M + i)[0]; + } + } + + float m_vec0[BLOCK_M]; + float l_vec0[BLOCK_M]; + float m_vec1[BLOCK_M]; + float l_vec1[BLOCK_M]; + for (int i = 0; i < block_m; i++) + { + m_vec0[i] = -FLT_MAX; + l_vec0[i] = 0.f; + m_vec1[i] = -FLT_MAX; + l_vec1[i] = 0.f; + } + + float* o_ptr0 = o_accum_head.row(0); + float* o_ptr1 = o_accum_head.row(BLOCK_M); + + for (int n_start = 0; n_start < dst_seqlen; n_start += BLOCK_N) + { + int n_end = n_start + BLOCK_N < dst_seqlen ? 
n_start + BLOCK_N : dst_seqlen; + int block_n = n_end - n_start; + + qk_int8_gemm_tiled(s_vec_ptr0, + q_int8_tile_head.row(0), + key_int8_head.row(n_start), + q_scales_tile_head.row(0), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + + if (has_q1) + { + qk_int8_gemm_tiled(s_vec_ptr1, + q_int8_tile_head.row(BLOCK_M), + key_int8_head.row(n_start), + q_scales_tile_head.row(BLOCK_M), + key_scales_head.row(n_start), + block_m, block_n, embed_dim, _scale); + } + + float scale_factors0[BLOCK_M]; + unsigned char changed0[BLOCK_M]; + softmax_tile(p_vec_ptr0, s_vec_ptr0, m_vec0, l_vec0, scale_factors0, changed0, block_m, block_n); + + float scale_factors1[BLOCK_M]; + unsigned char changed1[BLOCK_M]; + if (has_q1) + softmax_tile(p_vec_ptr1, s_vec_ptr1, m_vec1, l_vec1, scale_factors1, changed1, block_m, block_n); + + for (int i = 0; i < block_m; i++) + { + if (changed0[i]) + vec_scale(o_ptr0 + i * 4096, scale_factors0[i], 4096); + if (has_q1 && changed1[i]) + vec_scale(o_ptr1 + i * 4096, scale_factors1[i], 4096); + } + + if (has_q1) + { + if (n_start == 0) + pv_gemm_2heads_4096_avx512<16, 16, true>(o_ptr0, o_ptr1, p_vec_ptr0, p_vec_ptr1, value_head.row(n_start), block_m, block_n); + else + pv_gemm_2heads_4096_avx512<8, 16>(o_ptr0, o_ptr1, p_vec_ptr0, p_vec_ptr1, value_head.row(n_start), block_m, block_n); + } + else + { + if (n_start == 0) + pv_gemm_avx512<4, 64, true>(o_ptr0, p_vec_ptr0, value_head.row(n_start), block_m, block_n, 4096); + else + pv_gemm_avx512<4, 64>(o_ptr0, p_vec_ptr0, value_head.row(n_start), block_m, block_n, 4096); + } + } + + for (int i = 0; i < block_m; i++) + { + const float* optr0 = o_ptr0 + i * 4096; + float* outptr0 = top_blob_head0.row(m_start + i); + __m512 vinv_l0 = _mm512_set1_ps(1.f / l_vec0[i]); + for (int k = 0; k < 4096; k += 16) + _mm512_storeu_ps(outptr0 + k, _mm512_mul_ps(_mm512_loadu_ps(optr0 + k), vinv_l0)); + + if (has_q1) + { + const float* optr1 = o_ptr1 + i * 4096; + float* outptr1 = 
top_blob_head1.row(m_start + i); + __m512 vinv_l1 = _mm512_set1_ps(1.f / l_vec1[i]); + for (int k = 0; k < 4096; k += 16) + _mm512_storeu_ps(outptr1 + k, _mm512_mul_ps(_mm512_loadu_ps(optr1 + k), vinv_l1)); + } + } + } + + return 0; + } +#endif + #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < num_heads; q++) { @@ -4842,7 +5457,8 @@ static int sdpa_prefill_int8_x86( for (int i = 0; i < block_m; i++) { float* optr = o_accum_head.row(i); - vec_zero(optr, out_embed_dim); + if (!skip_o_accum_zero) + vec_zero(optr, out_embed_dim); } float m_vec[BLOCK_M]; @@ -4877,113 +5493,23 @@ static int sdpa_prefill_int8_x86( block_m, block_n, embed_dim, _scale); } - for (int i = 0; i < block_m; i++) + if (attn_mask) { - float* s_row = (block_m == 1) ? s_vec_ptr : (s_vec_ptr + i * block_n); - float* p_row = (block_m == 1) ? p_vec_ptr : (p_vec_ptr + i * block_n); - - if (attn_mask) - decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); - -#if __AVX512F__ - __m512 vmax = _mm512_set1_ps(m_vec[i]); - int j = 0; - for (; j + 15 < block_n; j += 16) - vmax = _mm512_max_ps(vmax, _mm512_loadu_ps(s_row + j)); - if (j < block_n) + for (int i = 0; i < block_m; i++) { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - vmax = _mm512_max_ps(vmax, _mm512_mask_loadu_ps(_mm512_set1_ps(-FLT_MAX), mask, s_row + j)); + float* s_row = s_vec_ptr + i * block_n; + decode_mask_vec(s_row, mask_head.row(m_start + i) + n_start, block_n); } - float m_new = _mm512_comp_reduce_max_ps(vmax); -#elif __AVX__ - __m256 vmax = _mm256_set1_ps(m_vec[i]); - int j = 0; - for (; j + 7 < block_n; j += 8) - vmax = _mm256_max_ps(vmax, _mm256_loadu_ps(s_row + j)); - float m_new = _mm256_reduce_max_ps(vmax); - for (; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); -#elif __SSE2__ - __m128 vmax = _mm_set1_ps(m_vec[i]); - int j = 0; - for (; j + 3 < block_n; j += 4) - vmax = _mm_max_ps(vmax, _mm_loadu_ps(s_row + j)); - float m_new = _mm_reduce_max_ps(vmax); - for (; j < 
block_n; j++) - m_new = std::max(m_new, s_row[j]); -#else - float m_new = m_vec[i]; - for (int j = 0; j < block_n; j++) - m_new = std::max(m_new, s_row[j]); -#endif - - float scale_factor = expf(m_vec[i] - m_new); - float l_new = l_vec[i] * scale_factor; + } - float* optr = o_accum_head.row(i); - vec_scale(optr, scale_factor, out_embed_dim); + float scale_factors[BLOCK_M]; + unsigned char changed[BLOCK_M]; + softmax_tile(p_vec_ptr, s_vec_ptr, m_vec, l_vec, scale_factors, changed, block_m, block_n); -#if __AVX512F__ - __m512 vm_new = _mm512_set1_ps(m_new); - __m512 vsum = _mm512_setzero_ps(); - j = 0; - for (; j + 15 < block_n; j += 16) - { - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_loadu_ps(s_row + j), vm_new)); - _mm512_storeu_ps(p_row + j, pvec); - vsum = _mm512_add_ps(vsum, pvec); - } - if (j < block_n) - { - __mmask16 mask = (__mmask16)((1u << (block_n - j)) - 1); - __m512 pvec = exp512_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, s_row + j), vm_new)); - _mm512_mask_storeu_ps(p_row + j, mask, pvec); - vsum = _mm512_mask_add_ps(vsum, mask, vsum, pvec); - } - l_new += _mm512_comp_reduce_add_ps(vsum); -#elif __AVX__ - __m256 vm_new = _mm256_set1_ps(m_new); - __m256 vsum = _mm256_setzero_ps(); - j = 0; - for (; j + 7 < block_n; j += 8) - { - __m256 pvec = exp256_ps(_mm256_sub_ps(_mm256_loadu_ps(s_row + j), vm_new)); - _mm256_storeu_ps(p_row + j, pvec); - vsum = _mm256_add_ps(vsum, pvec); - } - l_new += _mm256_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } -#elif __SSE2__ - __m128 vm_new = _mm_set1_ps(m_new); - __m128 vsum = _mm_setzero_ps(); - j = 0; - for (; j + 3 < block_n; j += 4) - { - __m128 pvec = exp_ps(_mm_sub_ps(_mm_loadu_ps(s_row + j), vm_new)); - _mm_storeu_ps(p_row + j, pvec); - vsum = _mm_add_ps(vsum, pvec); - } - l_new += _mm_reduce_add_ps(vsum); - for (; j < block_n; j++) - { - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } -#else - for (int j = 0; j < block_n; j++) - 
{ - p_row[j] = expf(s_row[j] - m_new); - l_new += p_row[j]; - } -#endif - - m_vec[i] = m_new; - l_vec[i] = l_new; + for (int i = 0; i < block_m; i++) + { + if (changed[i]) + vec_scale(o_accum_head.row(i), scale_factors[i], out_embed_dim); } if (kv_cache) @@ -5065,7 +5591,10 @@ static int sdpa_prefill_int8_x86( #endif case 4096: #if __AVX512F__ - pv_gemm_avx512<2, 128>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + if (n_start == 0) + pv_gemm_avx512<4, 64, true>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); + else + pv_gemm_avx512<4, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); break; #elif __AVX__ pv_gemm_avx<2, 64>(o_accum_head.row(0), p_vec_ptr, value.channel(q / num_heads_per_group).row(n_start), block_m, block_n, 4096); @@ -5820,11 +6349,14 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to num_group, dst_seqlen, embed_dim, out_embed_dim, v_num_blocks, cache_valid, past_seqlen, kv_cache, opt); - Mat o_accum(out_embed_dim, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat s_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat p_vec(BLOCK_N * BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); - Mat q_int8_tile(embed_dim, BLOCK_M, opt.num_threads, 1u, opt.workspace_allocator); - Mat q_scales_tile(1, BLOCK_M, opt.num_threads, 4u, opt.workspace_allocator); + const bool use_int8_prefill_2heads_workspace = src_seqlen > 1 && !kv_cache && !attn_mask && out_embed_dim == 4096 && num_heads_per_group >= 2; + const int int8_prefill_workspace_heads = use_int8_prefill_2heads_workspace ? 
2 : 1; + + Mat o_accum(out_embed_dim, BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); + Mat s_vec(BLOCK_N * BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); + Mat p_vec(BLOCK_N * BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); + Mat q_int8_tile(embed_dim, BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 1u, opt.workspace_allocator); + Mat q_scales_tile(1, BLOCK_M * int8_prefill_workspace_heads, opt.num_threads, 4u, opt.workspace_allocator); if (o_accum.empty() || s_vec.empty() || p_vec.empty() || q_int8_tile.empty() || q_scales_tile.empty()) return -100; diff --git a/src/layer/x86/sdpa_x86_avx512bf16.cpp b/src/layer/x86/sdpa_x86_avx512bf16.cpp index 6befdf52c33b..2379c43e92bd 100644 --- a/src/layer/x86/sdpa_x86_avx512bf16.cpp +++ b/src/layer/x86/sdpa_x86_avx512bf16.cpp @@ -37,12 +37,12 @@ template void qk_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const template void qk_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const unsigned short*, int, int, float); template void qk_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int, float); -template void pv_gemm_bf16s_avx512bf16_kernel_t<64>(float*, const float*, const unsigned short*, int, int); -template void pv_gemm_bf16s_avx512bf16_kernel_t<128>(float*, const float*, const unsigned short*, int, int); -template void pv_gemm_bf16s_avx512bf16_kernel_t<256>(float*, const float*, const unsigned short*, int, int); -template void pv_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const unsigned short*, int, int); -template void pv_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const unsigned short*, int, int); -template void pv_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int); +template void pv_gemm_bf16s_avx512bf16_kernel_t<64>(float*, const float*, const unsigned short*, 
int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<128>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<256>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<512>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<1024>(float*, const float*, const unsigned short*, int, int, bool); +template void pv_gemm_bf16s_avx512bf16_kernel_t<4096>(float*, const float*, const unsigned short*, int, int, bool); void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale) { @@ -77,27 +77,27 @@ void qk_gemm_bf16s_avx512bf16_qbf16(float* S, const unsigned short* Q, const uns qk_gemm_bf16s_avx512bf16_qbf16_kernel(S, Q, K, m, n, d, scale); } -void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d) +void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d, bool init_zero) { switch (d) { case 64: - pv_gemm_bf16s_avx512bf16_kernel_t<64>(O, P, V, m, n); + pv_gemm_bf16s_avx512bf16_kernel_t<64>(O, P, V, m, n, init_zero); break; case 128: - pv_gemm_bf16s_avx512bf16_kernel_t<128>(O, P, V, m, n); + pv_gemm_bf16s_avx512bf16_kernel_t<128>(O, P, V, m, n, init_zero); break; case 256: - pv_gemm_bf16s_avx512bf16_kernel_t<256>(O, P, V, m, n); + pv_gemm_bf16s_avx512bf16_kernel_t<256>(O, P, V, m, n, init_zero); break; case 512: - pv_gemm_bf16s_avx512bf16_kernel_t<512>(O, P, V, m, n); + pv_gemm_bf16s_avx512bf16_kernel_t<512>(O, P, V, m, n, init_zero); break; case 1024: - pv_gemm_bf16s_avx512bf16_kernel_t<1024>(O, P, V, m, n); + pv_gemm_bf16s_avx512bf16_kernel_t<1024>(O, P, V, m, n, init_zero); break; case 4096: - pv_gemm_bf16s_avx512bf16_kernel_t<4096>(O, P, V, m, n); + pv_gemm_bf16s_avx512bf16_kernel_t<4096>(O, P, V, m, n, init_zero); break; default: 
pv_gemm_bf16s_avx512bf16_kernel(O, P, V, m, n, d); diff --git a/src/layer/x86/sdpa_x86_bf16s.h b/src/layer/x86/sdpa_x86_bf16s.h index f08f2e1cdfd0..52721356c8bb 100644 --- a/src/layer/x86/sdpa_x86_bf16s.h +++ b/src/layer/x86/sdpa_x86_bf16s.h @@ -23,7 +23,7 @@ void decode_qk_dot_bf16s_avx512bf16(float* s, const float* q, const unsigned sho void decode_pv_gemv_bf16s_avx512bf16(float* out, const float* s, const unsigned short* V, int n_start, int block_n, int out_d); void qk_gemm_bf16s_avx512bf16(float* S, const float* Q, const unsigned short* K, int m, int n, int d, float scale); void qk_gemm_bf16s_avx512bf16_qbf16(float* S, const unsigned short* Q, const unsigned short* K, int m, int n, int d, float scale); -void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d); +void pv_gemm_bf16s_avx512bf16(float* O, const float* P, const unsigned short* V, int m, int n, int d, bool init_zero = false); #endif // --------------------------------------------------------------------------- @@ -2265,7 +2265,7 @@ static void qk_gemm_bf16s_avx512bf16_kernel_t(float* S, const float* Q, const un } template -static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n) +static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n, bool init_zero) { unsigned short* p_bf16 = (unsigned short*)_mm_malloc(m * n * sizeof(unsigned short), 64); for (int i = 0; i < m; i++) @@ -2296,8 +2296,8 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un __m512 acc[4][2]; for (int mi = 0; mi < 4; mi++) { - acc[mi][0] = _mm512_loadu_ps(op[mi] + 0); - acc[mi][1] = _mm512_loadu_ps(op[mi] + 16); + acc[mi][0] = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op[mi] + 0); + acc[mi][1] = init_zero ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op[mi] + 16); } for (int j = 0; j < n; j++) @@ -2327,10 +2327,10 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un { float* op0 = O + i * D + dd; float* op1 = O + (i + 1) * D + dd; - __m512 acc00 = _mm512_loadu_ps(op0 + 0); - __m512 acc01 = _mm512_loadu_ps(op0 + 16); - __m512 acc10 = _mm512_loadu_ps(op1 + 0); - __m512 acc11 = _mm512_loadu_ps(op1 + 16); + __m512 acc00 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op0 + 0); + __m512 acc01 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op0 + 16); + __m512 acc10 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op1 + 0); + __m512 acc11 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op1 + 16); for (int j = 0; j < n; j++) { const unsigned short* vptr = V + j * D + dd; @@ -2355,8 +2355,8 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un for (; i < m; i++) { float* optr = O + i * D + dd; - __m512 acc0 = _mm512_loadu_ps(optr + 0); - __m512 acc1 = _mm512_loadu_ps(optr + 16); + __m512 acc0 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(optr + 0); + __m512 acc1 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(optr + 16); for (int j = 0; j < n; j++) { const unsigned short* vptr = V + j * D + dd; @@ -2384,7 +2384,7 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un op[mi] = O + (i + mi) * D + dd; __m512 acc[4]; for (int mi = 0; mi < 4; mi++) - acc[mi] = _mm512_loadu_ps(op[mi]); + acc[mi] = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op[mi]); for (int j = 0; j < n; j++) { @@ -2404,8 +2404,8 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un { float* op0 = O + i * D + dd; float* op1 = O + (i + 1) * D + dd; - __m512 acc0 = _mm512_loadu_ps(op0); - __m512 acc1 = _mm512_loadu_ps(op1); + __m512 acc0 = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(op0); + __m512 acc1 = init_zero ? 
_mm512_setzero_ps() : _mm512_loadu_ps(op1); for (int j = 0; j < n; j++) { __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); @@ -2423,7 +2423,7 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un for (; i < m; i++) { float* optr = O + i * D + dd; - __m512 acc = _mm512_loadu_ps(optr); + __m512 acc = init_zero ? _mm512_setzero_ps() : _mm512_loadu_ps(optr); for (int j = 0; j < n; j++) { __m256i v0 = _mm256_loadu_si256((const __m256i*)(V + j * D + dd)); @@ -2441,7 +2441,7 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un for (int i = 0; i < m; i++) { float* optr = O + i * D + dd; - float acc = optr[0]; + float acc = init_zero ? 0.f : optr[0]; for (int j = 0; j < n; j++) acc += bfloat16_to_float32(p_bf16[i * n + j]) * bfloat16_to_float32(V[j * D + dd]); optr[0] = acc; @@ -2451,6 +2451,12 @@ static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const un _mm_free(p_bf16); } +template +static void pv_gemm_bf16s_avx512bf16_kernel_t(float* O, const float* P, const unsigned short* V, int m, int n) +{ + pv_gemm_bf16s_avx512bf16_kernel_t(O, P, V, m, n, false); +} + #endif // __AVX512F__ && __AVX512BF16__ #endif // SDPA_X86_BF16S_H diff --git a/tests/perf/perf_sdpa_prefill.cpp b/tests/perf/perf_sdpa_prefill.cpp index fae0eee19b53..02ef9cd8c892 100644 --- a/tests/perf/perf_sdpa_prefill.cpp +++ b/tests/perf/perf_sdpa_prefill.cpp @@ -4,6 +4,7 @@ #include "perfutil.h" #include +#include static bool match_env_int(const char* name, int value) { @@ -14,6 +15,34 @@ static bool match_env_int(const char* name, int value) return atoi(s) == value; } +static bool has_env_int(const char* name) +{ + const char* s = getenv(name); + return s && s[0]; +} + +static bool match_env_string(const char* name, const char* value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return true; + + return strcmp(s, value) == 0; +} + +static bool should_run_fp_prefill() +{ + return 
match_env_string("NCNN_PERF_DTYPE", "fp32") + || match_env_string("NCNN_PERF_DTYPE", "fp16ps") + || match_env_string("NCNN_PERF_DTYPE", "fp16psa") + || match_env_string("NCNN_PERF_DTYPE", "bf16ps"); +} + +static bool should_run_int8_prefill() +{ + return match_env_string("NCNN_PERF_DTYPE", "int8"); +} + static bool should_run_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) { return match_env_int("NCNN_PERF_SDPA_EMBED", embed_dim) @@ -22,6 +51,14 @@ static bool should_run_prefill(int embed_dim, int num_heads, int num_groups, int && match_env_int("NCNN_PERF_SDPA_SEQLEN", src_seqlen); } +static bool should_run_extended_prefill() +{ + return has_env_int("NCNN_PERF_SDPA_EMBED") + && has_env_int("NCNN_PERF_SDPA_HEADS") + && has_env_int("NCNN_PERF_SDPA_GROUPS") + && has_env_int("NCNN_PERF_SDPA_SEQLEN"); +} + // prefill phase: larger src_seqlen, no kv_cache (past_seqlen=0) static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int src_seqlen) { @@ -44,9 +81,12 @@ static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int inputs[1] = PerfMat(embed_dim, cur_seqlen, num_groups); // k inputs[2] = PerfMat(out_embed_dim, cur_seqlen, num_groups); // v - perf_layer("SDPA", pd, weights, inputs, 1, - "embed=%d heads=%d groups=%d seqlen=%d", - embed_dim, num_heads, num_groups, src_seqlen); + if (should_run_fp_prefill()) + { + perf_layer("SDPA", pd, weights, inputs, 1, + "embed=%d heads=%d groups=%d seqlen=%d", + embed_dim, num_heads, num_groups, src_seqlen); + } // int8 variant ncnn::ParamDict pd_int8; @@ -54,9 +94,12 @@ static void perf_sdpa_prefill(int embed_dim, int num_heads, int num_groups, int pd_int8.set(6, 0.f); // scale = 0 pd_int8.set(7, 0); // kv_cache = 0 pd_int8.set(18, 2); // int8_scale_term - perf_layer_int8("SDPA", pd_int8, weights, inputs, 1, - "embed=%d heads=%d groups=%d seqlen=%d", - embed_dim, num_heads, num_groups, src_seqlen); + if (should_run_int8_prefill()) + { + perf_layer_int8("SDPA", pd_int8, 
weights, inputs, 1, + "embed=%d heads=%d groups=%d seqlen=%d", + embed_dim, num_heads, num_groups, src_seqlen); + } } int main() @@ -88,6 +131,12 @@ int main() // GQA/MQA configurations // GQA: num_groups < num_heads + perf_sdpa_prefill(4096, 32, 16, 16); + perf_sdpa_prefill(4096, 32, 16, 32); + perf_sdpa_prefill(4096, 32, 16, 64); + perf_sdpa_prefill(4096, 32, 8, 16); + perf_sdpa_prefill(4096, 32, 8, 32); + perf_sdpa_prefill(4096, 32, 8, 64); perf_sdpa_prefill(4096, 32, 4, 16); perf_sdpa_prefill(4096, 32, 4, 32); perf_sdpa_prefill(4096, 32, 4, 64); @@ -97,5 +146,20 @@ int main() perf_sdpa_prefill(4096, 32, 1, 32); perf_sdpa_prefill(4096, 32, 1, 64); + // Longer large-model cases are opt-in through NCNN_PERF_SDPA_* filters. + if (should_run_extended_prefill()) + { + perf_sdpa_prefill(4096, 32, 32, 128); + perf_sdpa_prefill(4096, 32, 32, 256); + perf_sdpa_prefill(4096, 32, 16, 128); + perf_sdpa_prefill(4096, 32, 16, 256); + perf_sdpa_prefill(4096, 32, 8, 128); + perf_sdpa_prefill(4096, 32, 8, 256); + perf_sdpa_prefill(4096, 32, 4, 128); + perf_sdpa_prefill(4096, 32, 4, 256); + perf_sdpa_prefill(4096, 32, 1, 128); + perf_sdpa_prefill(4096, 32, 1, 256); + } + return 0; } diff --git a/tests/perf/perfutil.cpp b/tests/perf/perfutil.cpp index 79c81b7d9841..61a58a441533 100644 --- a/tests/perf/perfutil.cpp +++ b/tests/perf/perfutil.cpp @@ -14,6 +14,10 @@ #include #include +#if defined(__linux__) +#include +#endif + #if NCNN_VULKAN #include "command.h" #include "gpu.h" @@ -26,6 +30,60 @@ #define PERF_RUN_COUNT 20 #define PERF_TARGET_MIN_MS 5.0 +static int perf_env_int(const char* name, int default_value, int min_value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return default_value; + + int v = atoi(s); + return v < min_value ? min_value : v; +} + +static double perf_env_double(const char* name, double default_value, double min_value) +{ + const char* s = getenv(name); + if (!s || !s[0]) + return default_value; + + double v = atof(s); + return v < min_value ? 
min_value : v; +} + +static void setup_perf_cpu_affinity() +{ + static bool initialized = false; + if (initialized) + return; + initialized = true; + +#if defined(__linux__) + const char* s = getenv("NCNN_PERF_CPU_AFFINITY"); + if (!s || !s[0]) + return; + + int cpu = atoi(s); + if (cpu < 0) + return; + if (cpu >= CPU_SETSIZE) + { + fprintf(stderr, "NCNN_PERF_CPU_AFFINITY=%d exceeds CPU_SETSIZE=%d\n", cpu, CPU_SETSIZE); + return; + } + + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(cpu, &cpuset); + + if (sched_setaffinity(0, sizeof(cpuset), &cpuset) != 0) + { + fprintf(stderr, "NCNN_PERF_CPU_AFFINITY=%d sched_setaffinity failed\n", cpu); + } +#else + (void)getenv("NCNN_PERF_CPU_AFFINITY"); +#endif +} + // benchmark result for a single test case struct PerfResult { @@ -268,6 +326,12 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, int forced_inner_loops, PerfResult& result) { + setup_perf_cpu_affinity(); + + const int warmup_count = perf_env_int("NCNN_PERF_WARMUP_COUNT", PERF_WARMUP_COUNT, 1); + const int run_count = perf_env_int("NCNN_PERF_RUN_COUNT", PERF_RUN_COUNT, 1); + const double target_min_ms = perf_env_double("NCNN_PERF_TARGET_MIN_MS", PERF_TARGET_MIN_MS, 0.0); + ncnn::Layer* op = ncnn::create_layer_cpu(layer_type); if (!op) { @@ -291,7 +355,7 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, // warmup and calibrate inner loop count from warmup min time double warmup_min_ms = DBL_MAX; - for (int i = 0; i < PERF_WARMUP_COUNT; i++) + for (int i = 0; i < warmup_count; i++) { double t0 = ncnn::get_current_time(); int ret = run_layer_forward_cpu(op, converted, top_blob_count, opt); @@ -315,19 +379,19 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, { // calibrate inner loop count to power of 10, so total time >= PERF_TARGET_MIN_MS inner_loops = 1; - if (warmup_min_ms > 0 && warmup_min_ms < PERF_TARGET_MIN_MS) + if (warmup_min_ms > 0 && warmup_min_ms < target_min_ms) { 
- while (inner_loops * warmup_min_ms < PERF_TARGET_MIN_MS) + while (inner_loops * warmup_min_ms < target_min_ms) inner_loops *= 10; } } - double* times = new double[PERF_RUN_COUNT]; + double* times = new double[run_count]; double time_sum = 0; double time_min_val = DBL_MAX; double time_max_val = -DBL_MAX; - for (int i = 0; i < PERF_RUN_COUNT; i++) + for (int i = 0; i < run_count; i++) { double start = ncnn::get_current_time(); @@ -345,16 +409,16 @@ static int perf_layer_cpu(const char* layer_type, const ncnn::ParamDict& pd, if (t > time_max_val) time_max_val = t; } - sort_doubles(times, PERF_RUN_COUNT); + sort_doubles(times, run_count); double time_median_val; - if (PERF_RUN_COUNT % 2 == 0) - time_median_val = (times[PERF_RUN_COUNT / 2 - 1] + times[PERF_RUN_COUNT / 2]) / 2.0; + if (run_count % 2 == 0) + time_median_val = (times[run_count / 2 - 1] + times[run_count / 2]) / 2.0; else - time_median_val = times[PERF_RUN_COUNT / 2]; + time_median_val = times[run_count / 2]; result.time_min = time_min_val; result.time_max = time_max_val; - result.time_avg = time_sum / PERF_RUN_COUNT; + result.time_avg = time_sum / run_count; result.time_median = time_median_val; result.loop_count = inner_loops; @@ -410,6 +474,11 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, int forced_inner_loops, PerfResult& result) { + const int warmup_count = perf_env_int("NCNN_PERF_GPU_WARMUP_COUNT", PERF_GPU_WARMUP_COUNT, 1); + const int warmup_batch = perf_env_int("NCNN_PERF_GPU_WARMUP_BATCH", PERF_GPU_WARMUP_BATCH, 1); + const int run_count = perf_env_int("NCNN_PERF_RUN_COUNT", PERF_RUN_COUNT, 1); + const double target_min_ms = perf_env_double("NCNN_PERF_TARGET_MIN_MS", PERF_TARGET_MIN_MS, 0.0); + ncnn::Layer* op = ncnn::create_layer_vulkan(layer_type); if (!op) { @@ -516,17 +585,17 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, // warmup and calibrate inner loop count from warmup min time // batch multiple forwards per submit to 
amortize submit_and_wait overhead double warmup_min_ms = DBL_MAX; - for (int i = 0; i < PERF_GPU_WARMUP_COUNT; i++) + for (int i = 0; i < warmup_count; i++) { ncnn::VkCompute cmd(vkdev); - for (int b = 0; b < PERF_GPU_WARMUP_BATCH; b++) + for (int b = 0; b < warmup_batch; b++) { run_layer_forward_gpu(op, vk_inputs, top_blob_count, cmd, opt); } double t0 = ncnn::get_current_time(); cmd.submit_and_wait(); double t1 = ncnn::get_current_time(); - double t = (t1 - t0) / PERF_GPU_WARMUP_BATCH; + double t = (t1 - t0) / warmup_batch; if (t < warmup_min_ms) warmup_min_ms = t; } @@ -539,20 +608,20 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, { // calibrate inner loop count to power of 10, so total time >= PERF_TARGET_MIN_MS inner_loops = 1; - if (warmup_min_ms > 0 && warmup_min_ms < PERF_TARGET_MIN_MS) + if (warmup_min_ms > 0 && warmup_min_ms < target_min_ms) { - while (inner_loops * warmup_min_ms < PERF_TARGET_MIN_MS) + while (inner_loops * warmup_min_ms < target_min_ms) inner_loops *= 10; } } // record inner_loops forwards into one command buffer, single submit // this measures pure GPU kernel time, excluding per-launch overhead - double* times = new double[PERF_RUN_COUNT]; + double* times = new double[run_count]; double time_sum = 0; double time_min_val = DBL_MAX; double time_max_val = -DBL_MAX; - for (int i = 0; i < PERF_RUN_COUNT; i++) + for (int i = 0; i < run_count; i++) { ncnn::VkCompute cmd(vkdev); for (int k = 0; k < inner_loops; k++) @@ -572,16 +641,16 @@ static int perf_layer_gpu(const char* layer_type, const ncnn::ParamDict& pd, if (t > time_max_val) time_max_val = t; } - sort_doubles(times, PERF_RUN_COUNT); + sort_doubles(times, run_count); double time_median_val; - if (PERF_RUN_COUNT % 2 == 0) - time_median_val = (times[PERF_RUN_COUNT / 2 - 1] + times[PERF_RUN_COUNT / 2]) / 2.0; + if (run_count % 2 == 0) + time_median_val = (times[run_count / 2 - 1] + times[run_count / 2]) / 2.0; else - time_median_val = times[PERF_RUN_COUNT 
/ 2]; + time_median_val = times[run_count / 2]; result.time_min = time_min_val; result.time_max = time_max_val; - result.time_avg = time_sum / PERF_RUN_COUNT; + result.time_avg = time_sum / run_count; result.time_median = time_median_val; result.loop_count = inner_loops; diff --git a/tests/test_sdpa.cpp b/tests/test_sdpa.cpp index c043c6b731bd..8ff4cc900a9a 100644 --- a/tests/test_sdpa.cpp +++ b/tests/test_sdpa.cpp @@ -117,6 +117,17 @@ static int test_sdpa_1() || test_sdpa_int8(RandomMat(28, 17, 15), RandomMat(28, 127, 5), RandomMat(32, 127, 5), 0, 0.1f) || test_sdpa_int8(RandomMat(28, 17, 15), RandomMat(28, 32, 5), RandomMat(11, 32, 5), 1, -0.4f); } + +static int test_sdpa_int8_large_dim() +{ + if (!getenv("NCNN_TEST_SDPA_LARGE_DIM")) + return 0; + + return 0 + || test_sdpa_int8(RandomMat(4096, 16, 32), RandomMat(4096, 16, 1), RandomMat(4096, 16, 1), 0, 1.f / 64.f) + || test_sdpa_int8(RandomMat(4096, 16, 32), RandomMat(4096, 16, 4), RandomMat(4096, 16, 4), 0, 1.f / 64.f) + || test_sdpa_int8(RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), RandomMat(4096, 32, 32), 0, 1.f / 64.f); +} #endif int main() @@ -124,7 +135,7 @@ int main() SRAND(7767517); #if NCNN_INT8 - return test_sdpa_0() || test_sdpa_1() || test_sdpa_large_dim(); + return test_sdpa_0() || test_sdpa_1() || test_sdpa_large_dim() || test_sdpa_int8_large_dim(); #else return test_sdpa_0() || test_sdpa_large_dim(); #endif From 32c0e6ba9b39b37b4d84ac4d93ffebf0be275069 Mon Sep 17 00:00:00 2001 From: futz12 <56149058+futz12@users.noreply.github.com> Date: Wed, 3 Jun 2026 02:22:21 +0000 Subject: [PATCH 53/53] apply code-format changes --- src/layer/x86/sdpa_x86.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layer/x86/sdpa_x86.cpp b/src/layer/x86/sdpa_x86.cpp index 2d67fbd7c210..ff26193748e2 100644 --- a/src/layer/x86/sdpa_x86.cpp +++ b/src/layer/x86/sdpa_x86.cpp @@ -6265,8 +6265,8 @@ int SDPA_x86::forward(const std::vector& bottom_blobs, std::vector& to #if __AVX512F__ #if 
NCNN_RUNTIME_CPU && NCNN_AVX512BF16 && !__AVX512BF16__ use_qbf16_prefill = use_bf16_path && src_seqlen > 1 && embed_dim > 512 - && (src_seqlen > 16 || num_heads_per_group > 1) - && ncnn::cpu_support_x86_avx512_bf16(); + && (src_seqlen > 16 || num_heads_per_group > 1) + && ncnn::cpu_support_x86_avx512_bf16(); #endif #endif Mat query_fp32;