diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/cm_sdpa_vlen.cm b/src/plugins/intel_gpu/src/graph/impls/cm/cm_sdpa_vlen.cm index f06979e72dc50d..d514d1157ac9cb 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/cm_sdpa_vlen.cm +++ b/src/plugins/intel_gpu/src/graph/impls/cm/cm_sdpa_vlen.cm @@ -1,4 +1,3 @@ - /******************************************************************************* * Copyright (c) 2022-2025 Intel Corporation * @@ -15,7 +14,7 @@ * limitations under the License. *******************************************************************************/ namespace KERNEL_NAME { - + #include "cm_sdpa_common.hpp" #ifdef CM_HAS_LSC_UNTYPED_2D @@ -138,6 +137,7 @@ _GENX_MAIN_ void KERNEL_NAME( reinterpret_cast(output + qo_offset)); #endif #else +#if CMFLA_USE_STUB_KERNEL == 0 sdpa_kernel( slm_K, slm_V, @@ -156,7 +156,29 @@ _GENX_MAIN_ void KERNEL_NAME( kv_offset * sizeof(half), qo_offset * sizeof(half) ); +#else + sdpa_kernel_mma( + slm_K, + slm_V, + wg_local_id, + local_size, + 0, //q_start, + kv_seq_len, //kv_stop, + q_len, //q_len, + kv_seq_len, //kv_len, + query, + key, + value, + output, + qo_offset * sizeof(half), + kv_offset * sizeof(half), + kv_offset * sizeof(half), + qo_offset * sizeof(half) + ); + + + +#endif #endif } - } // NAMESPACE diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp index 325cd25fc5ba59..d080ffcb76f4f5 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_attention_common.hpp @@ -38,7 +38,6 @@ constexpr float scale_factor = CMFLA_SCALE_FACTOR; static_assert(q_step == 16 || q_step == 8); static_assert(kv_step == 16); -static_assert(CM_HAS_DPAS); #define DEBUG_SHOW 1 #if !DEBUG_SHOW diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp index 402cacb2e77674..e401e00a1fecea 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/include/cm_sdpa_common.hpp @@ -13,258 +13,378 @@ * See the License for the specific language governing permissions and * limitations under the License. *******************************************************************************/ -#include "cm_attention_common.hpp" +#include +#include + +//# CM-compiler is C++17 +static_assert(__cplusplus >= 201703L); + +#define SystolicDepth 8 +#define RepeatCount 8 +#define VNNI_WIDTH 2 +#define REG_K (SystolicDepth * VNNI_WIDTH) +#define REG_M RepeatCount +//REG_N +// Xe1: 8 +// Xe2: 16 +#define REG_N (CM_GRF_WIDTH/32) + +#define kv_step REG_K +#define q_step REG_N + +constexpr float scale_factor = CMFLA_SCALE_FACTOR; + +static_assert(q_step == 16 || q_step == 8); +static_assert(kv_step == 16); +//static_assert(CM_HAS_DPAS); + +template +void show(const matrix mat) { + printf("Matrix [%d, %d]:\n", M, N); + for(int m = 0; m < M; m ++) { + printf("\t["); + for(int n = 0; n < N; n ++) { + printf("%8.4f,", mat[m][n]); + } + printf("],\n"); + } + printf("]\n"); +} -#ifdef CM_HAS_LSC_UNTYPED_2D -//@prefetch_u8 would have duplicated decompress perf issue. comments out for now. 
-// template -// void sdpa_kernel_lsc_prefetch_u8( -// int wg_local_id, -// int q_start, -// int kv_stop, // -// int q_len, //q_step -// int kv_len, //not used for now -// svmptr_t q_base [[type("svmptr_t")]], -// svmptr_t k_cache_base [[type("svmptr_t")]], -// svmptr_t v_cache_base [[type("svmptr_t")]], -// svmptr_t o_base [[type("svmptr_t")]], -// int32_t past_lens, -// int32_t* block_indices [[type("svmptr_t")]]) { -// constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); -// constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads*2) * head_size * sizeof(half)) : o_pitch; -// //[block_num, kv_heads, block_size, head_size] -// constexpr uint kv_pitch = head_size * sizeof(uint8_t); - -// vector cur_max; -// vector cur_sum; - -// cur_max = -3e38f; -// cur_sum = 0; -// constexpr int num_P_tiles = REG_N / REG_M; -// matrix rQ; -// matrix rO; - -// auto q_tokens_left = q_len;// - q_start; -// static_assert(q_step == REG_N); -// static_assert(kv_step == REG_K); - -// if (q_tokens_left < 0) q_tokens_left = 0; -// if (q_tokens_left > q_step) q_tokens_left = q_step; - -// if (q_tokens_left > 0) { -// lsc::block_2d_desc b2dQ(reinterpret_cast(q_base), q_tokens_left - 1, head_size*sizeof(half) - 1, q_pitch - 1, 0, 0); -// #pragma unroll -// for(int k = 0, ri = 0; k < head_size/2; k += REG_K/2, ri++) { -// cm_load(rQ[ri].format(), b2dQ.set_block_x(k)); -// rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); -// } -// } - -// lsc::block_2d_desc b2dKV(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); - -// static_assert(wg_local_size == 16); -// lsc::block_2d_desc b2dKV_prefetch(k_cache_base, CMPA_BLOCK_SZ - 1, head_size*sizeof(uint8_t) - 1, kv_pitch - 1, 0, 0); -// // constexpr int blk_stride = CMFLA_NUM_KV_HEADS * CMFLA_HEAD_SIZE * CMPA_BLOCK_SZ; -// constexpr int quan_blk_stride = CMFLA_NUM_KV_HEADS * (CMFLA_HEAD_SIZE+4) * CMPA_BLOCK_SZ * sizeof(uint8_t); - - - -// int causal_left = q_start+past_lens; -// for(int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step) { -// auto cur_block_id = block_indices[kv_pos / CMPA_BLOCK_SZ]; -// //For the last step, duplicate prefetch here. -// uint32_t prefetch_kv_pos = (kv_pos+kv_step) >= kv_stop ? 
kv_pos : (kv_pos+kv_step); -// auto prefetch_block_id = block_indices[prefetch_kv_pos / CMPA_BLOCK_SZ]; -// uint32_t dscale_offset = cur_block_id*quan_blk_stride + CMPA_BLOCK_SZ * head_size * sizeof(uint8_t) + kv_pos%CMPA_BLOCK_SZ*sizeof(half); - -// vector dscale; -// vector zp; -// cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset), dscale); -// cm_svm_block_read(reinterpret_cast( k_cache_base + dscale_offset + CMPA_BLOCK_SZ*sizeof(half)), zp); - -// // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { -// // show(dscale.format()); -// // } -// //# St = k @ Qt - -// matrix St; // = ugemm_KQ(slm_K, rQ, slm_offset); -// { -// constexpr int num_K = kv_step/REG_M; -// auto St2 = St.format(); -// matrix Kmat; -// auto quan_Kmat = Kmat.format().row(1).format(); -// auto dq_Kmat = Kmat.format(); -// //cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); - -// b2dKV_prefetch.set_base_ptr(reinterpret_cast(k_cache_base+prefetch_block_id*quan_blk_stride)); -// b2dKV_prefetch.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); -// cm_prefetch(b2dKV_prefetch.set_block_x(0)); - -// b2dKV.set_base_ptr(reinterpret_cast(k_cache_base+cur_block_id*quan_blk_stride)); -// b2dKV.set_block_y(kv_pos%CMPA_BLOCK_SZ); - -// cm_load(quan_Kmat.format(), b2dKV.set_block_x(0)); -// // if (cm_local_id(2) == 0 && cm_group_id(2) == 0) { -// // show(quan_Kmat.format(), false); -// // } -// #pragma unroll -// for(int r = 0; r < kv_step; r++) { -// dq_Kmat[r] = quan_Kmat[r] - zp[r]; -// dq_Kmat[r] = cm_mul(dq_Kmat[r], dscale[r]); -// } - -// #pragma unroll -// for(int k = 0; k < num_K; k++) -// St2.row(k) = cm_dpas( -// 0, -// rQ[0].format(), -// Kmat[k].format()); - -// #pragma unroll -// for(int ri = 1; ri < head_size/REG_K; ri++) { -// cm_prefetch(b2dKV_prefetch.set_block_x(ri*REG_K)); -// //cm_load(Kmat.format(), b2dKV.set_block_x(ri*REG_K)); -// cm_load(quan_Kmat.format(), b2dKV.set_block_x(ri*REG_K)); -// #pragma unroll -// for(int r = 0; r < kv_step; r++) { -// dq_Kmat[r] = quan_Kmat[r] - zp[r]; -// dq_Kmat[r] = cm_mul(dq_Kmat[r], dscale[r]); -// } -// #pragma unroll -// for(int k = 0; k < num_K; k++) { -// St2.row(k) = cm_dpas( -// St2.row(k), -// rQ[ri].format(), -// Kmat[k].format()); -// } -// } -// } -// if constexpr (use_causal_mask) { -// // since kv_step == q_step == 16, causal_left is n*kv_step -// if (causal_left == 0) { -// apply_causal_mask<1>(St); -// } else if (causal_left < 0) { -// St = -3.4e38f; -// } -// causal_left -= kv_step; -// } else { -// int kv_tokens = kv_stop - kv_pos; -// // LSC ensures no overflow-access, but mask off k-tails attn-score is still required -// for(int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; -// } - -// //show(St); -// auto max_comp = online_softmax_update(St, cur_max, cur_sum); - -// matrix P; -// Transpose2DMatrix(St, P); - -// b2dKV_prefetch.set_base_ptr(reinterpret_cast(v_cache_base+prefetch_block_id*quan_blk_stride)); -// b2dKV_prefetch.set_block_y((prefetch_kv_pos + wg_local_id) % CMPA_BLOCK_SZ); - -// b2dKV.set_base_ptr(reinterpret_cast(v_cache_base+cur_block_id*quan_blk_stride)); -// b2dKV.set_block_y(kv_pos%CMPA_BLOCK_SZ); - - -// cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset), dscale); -// cm_svm_block_read(reinterpret_cast(v_cache_base+dscale_offset+CMPA_BLOCK_SZ*sizeof(half)), zp); - -// { -// matrix VmatVNNI2; -// matrix Vmat; -// auto quanVmat = Vmat.format().row(1).format(); -// int kv_tokens = kv_stop - kv_pos; -// if (kv_pos == 0) { -// // ugemm_PV0(slm_V, P, rO, slm_offset); -// auto P2 = 
P.format(); -// #pragma unroll -// for(int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { -// cm_prefetch(b2dKV_prefetch.set_block_x(k)); -// cm_load(quanVmat.format(), b2dKV.set_block_x(k)); -// #pragma unroll -// for(int r = 0; r < kv_step;r++) { -// Vmat[r] = quanVmat[r]-zp[r]; -// Vmat[r] = cm_mul(Vmat[r], dscale[r]); -// } -// for(int r = kv_step-1; r>=kv_tokens;r--) { -// Vmat[r] = 0; -// } - -// prepackAsVNNIWidth2(Vmat, VmatVNNI2); - -// #pragma unroll -// for(int p = 0; p < num_P_tiles; p++) { -// rO[ri + p] = cm_dpas( -// 0, -// VmatVNNI2.format(), -// P2.row(p).format()); -// } -// } -// } -// else { -// //ugemm_PV1(slm_V, P, max_comp, rO, slm_offset); -// auto P2 = P.format(); -// #pragma unroll -// for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { -// cm_prefetch(b2dKV_prefetch.set_block_x(k)); -// cm_load(quanVmat.format(), b2dKV.set_block_x(k)); - -// #pragma unroll -// for(int r = 0; r < kv_step;r++) { -// Vmat[r] = quanVmat[r]-zp[r]; -// Vmat[r] = cm_mul(Vmat[r], dscale[r]); -// } -// for(int r = kv_step-1; r>=kv_tokens;r--) { -// Vmat[r] = 0; -// } - -// prepackAsVNNIWidth2(Vmat, VmatVNNI2); -// //# compensate cur_O -// // matrix rO; -// #pragma unroll -// for(int p = 0; p < num_P_tiles; p++) { -// auto cO = rO[ri + p].format(); -// #pragma unroll -// for(int r = 0; r < REG_M; r++) -// cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); -// } - -// #pragma unroll -// for(int p = 0; p < num_P_tiles; p++) { -// rO[ri + p] = cm_dpas( -// rO[ri + p].format(), -// VmatVNNI2.format(), -// P2.row(p).format()); -// } -// } -// } -// } -// } -// if (q_tokens_left == 0) return; - -// //# save cur_O/cur_sum.transpose(0, 1) -// matrix cur_O_f16; -// cur_sum = cm_inv(cur_sum); - -// lsc::block_2d_desc b2dO(o_base, q_tokens_left - 1, head_size*sizeof(half) - 1, o_pitch - 1, 0, 0); - -// #pragma unroll -// for(int k = 0, ri=0; k < head_size; k += REG_N, ri += num_P_tiles) { -// #pragma unroll -// for(int p = 0; p < num_P_tiles; p++) { -// auto cO = rO[ri + p].format(); -// #pragma unroll -// for(int r = 0; r < cO.n_rows(); r++) { -// cur_O_f16[r + p*REG_M] = cm_mul(cO.row(r), cur_sum[r + p*REG_M]); - -// } -// } -// b2dO.set_block_x(k); -// cm_store(b2dO.set_block_y(0), cur_O_f16.format().row(0)); -// cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); -// } -// } +template +CM_INLINE void Transpose_16x16(matrix_ref in, + matrix_ref out) { + matrix bBuf; + bBuf.row(0) = in.template select<4, 1, 4, 4>(0, 0); // 0,4,8,c + bBuf.row(1) = in.template select<4, 1, 4, 4>(4, 0); // 0,4,8,c + bBuf.row(2) = in.template select<4, 1, 4, 4>(8, 0); // 0,4,8,c + bBuf.row(3) = in.template select<4, 1, 4, 4>(12, 0); // 0,4,8,c + bBuf.row(4) = in.template select<4, 1, 4, 4>(0, 1); // 1,5,9,d + bBuf.row(5) = in.template select<4, 1, 4, 4>(4, 1); // 1,5,9,d + bBuf.row(6) = in.template select<4, 1, 4, 4>(8, 1); // 1,5,9,d + bBuf.row(7) = in.template select<4, 1, 4, 4>(12, 1); // 1,5,9,d + bBuf.row(8) = in.template select<4, 1, 4, 4>(0, 2); // 2,6,a,e + bBuf.row(9) = in.template select<4, 1, 4, 4>(4, 2); // 2,6,a,e + bBuf.row(10) = in.template select<4, 1, 4, 4>(8, 2); // 2,6,a,e + bBuf.row(11) = in.template select<4, 1, 4, 4>(12, 2); // 2,6,a,e + bBuf.row(12) = in.template select<4, 1, 4, 4>(0, 3); // 3,7,b,f + bBuf.row(13) = in.template select<4, 1, 4, 4>(4, 3); // 3,7,b,f + bBuf.row(14) = in.template select<4, 1, 4, 4>(8, 3); // 3,7,b,f + bBuf.row(15) = in.template select<4, 1, 4, 4>(12, 3); // 3,7,b,f + + out.row(0) = bBuf.template select<4, 1, 4, 4>(0, 0); 
// 0 + out.row(1) = bBuf.template select<4, 1, 4, 4>(4, 0); // 1 + out.row(2) = bBuf.template select<4, 1, 4, 4>(8, 0); // 2 + out.row(3) = bBuf.template select<4, 1, 4, 4>(12, 0); // 3 + out.row(4) = bBuf.template select<4, 1, 4, 4>(0, 1); // 4 + out.row(5) = bBuf.template select<4, 1, 4, 4>(4, 1); // 5 + out.row(6) = bBuf.template select<4, 1, 4, 4>(8, 1); // 6 + out.row(7) = bBuf.template select<4, 1, 4, 4>(12, 1); // 7 + out.row(8) = bBuf.template select<4, 1, 4, 4>(0, 2); // 8 + out.row(9) = bBuf.template select<4, 1, 4, 4>(4, 2); // 9 + out.row(10) = bBuf.template select<4, 1, 4, 4>(8, 2); // a + out.row(11) = bBuf.template select<4, 1, 4, 4>(12, 2); // b + out.row(12) = bBuf.template select<4, 1, 4, 4>(0, 3); // c + out.row(13) = bBuf.template select<4, 1, 4, 4>(4, 3); // d + out.row(14) = bBuf.template select<4, 1, 4, 4>(8, 3); // e + out.row(15) = bBuf.template select<4, 1, 4, 4>(12, 3); // f +} +template +CM_INLINE void Transpose_8x8(matrix_ref in, matrix_ref out) { + matrix temp; + temp.row(0) = in.template select<2, 1, 4, 2>(0, 0); + temp.row(1) = in.template select<2, 1, 4, 2>(2, 0); + temp.row(2) = in.template select<2, 1, 4, 2>(4, 0); + temp.row(3) = in.template select<2, 1, 4, 2>(6, 0); + temp.row(4) = in.template select<2, 1, 4, 2>(0, 1); + temp.row(5) = in.template select<2, 1, 4, 2>(2, 1); + temp.row(6) = in.template select<2, 1, 4, 2>(4, 1); + temp.row(7) = in.template select<2, 1, 4, 2>(6, 1); + + out.row(0) = temp.template select<4, 1, 2, 4>(0, 0); + out.row(2) = temp.template select<4, 1, 2, 4>(0, 1); + out.row(4) = temp.template select<4, 1, 2, 4>(0, 2); + out.row(6) = temp.template select<4, 1, 2, 4>(0, 3); + out.row(1) = temp.template select<4, 1, 2, 4>(4, 0); + out.row(3) = temp.template select<4, 1, 2, 4>(4, 1); + out.row(5) = temp.template select<4, 1, 2, 4>(4, 2); + out.row(7) = temp.template select<4, 1, 2, 4>(4, 3); +} + +// function templates cannot be partially specialized; use overloading to achieve the same effect +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in, out); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_16x16(in, out); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); + Transpose_8x8(in.select<8, 1, 8, 1>(8,0), out.select<8, 1, 8, 1>(0,8)); +} +template +inline void Transpose2DMatrix(matrix_ref in, matrix_ref out) { + Transpose_8x8(in.select<8, 1, 8, 1>(0,0), out.select<8, 1, 8, 1>(0,0)); + Transpose_8x8(in.select<8, 1, 8, 1>(0,8), out.select<8, 1, 8, 1>(8,0)); +} + +template +CM_INLINE void slm_read_2d(matrix_ref out, uint slm, int offset) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_slm_block_read(slm, GENX_DWALIGNED, offset + i*n_stride*sizeof(T), out.row(i)); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_svm_block_read(base + i * pitch, out[i]); + } +} + +template +CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + out.row(i).format() = cm_load(base, offset + i * pitch); + } +} + +template +CM_INLINE void cm_load_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + out.row(i).format() = cm_load(base, offset + i * pitch); + } +} + +template +CM_INLINE void 
cm_store_2d(matrix_ref out, SurfaceIndex base, uint offset, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_store(base, offset + i * pitch, out.row(i).format()); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, vector_ref offsets) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++) { + cm_svm_block_read(base + offsets[i], out[i]); + } +} + +template +CM_INLINE void svm_read_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch, n_rows--) { + if (n_rows > 0) cm_svm_block_read(base, out[i]); + } +} + +template +CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch) { + cm_svm_block_write(base, out[i]); + } +} + +template +CM_INLINE void svm_write_2d(matrix_ref out, svmptr_t base, uint pitch, int n_rows) { + #pragma unroll + for(int i = 0; i < out.n_rows(); i++, base += pitch) { + if (i < n_rows) cm_svm_block_write(base, out[i]); + } +} + +CM_INLINE uint64_t get_clock() { + auto clk = cm_clock(); + return ((uint64_t)clk[1]) << 32 | clk[0]; +} + + +template +inline matrix ugemm_KQ(uint slm_K, matrix_ref Qt, uint slm_offset = 0) { + matrix St; + constexpr int num_K = _kv_step/REG_M; + auto St2 = St.format(); + + matrix Kmat; + cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + #pragma unroll + for(int k = 0; k < num_K; k++) + St2.row(k) = cm_dpas(0, Qt[0].format(), Kmat[k].format()); + + #pragma unroll + for(int ri = 1; ri < num_Qt; ri++) { + cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); + #pragma unroll + for(int k = 0; k < num_K; k++) { + St2.row(k) = cm_dpas(St2.row(k), Qt[ri].format(), Kmat[k].format()); + } + } + return St; +} + +template +inline void ugemm_PV0(uint slm_V, matrix_ref P, matrix_ref rO, uint slm_offset = 0) { + constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; + + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri = 0; k < _head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + 0, + Vmat.format(), + P2.row(p).format()); + //show(rO[ri + p].format()); + } + } +} + +template +inline void ugemm_PV1(uint slm_V, matrix_ref P, vector_ref max_comp, + matrix_ref rO, uint slm_offset = 0) { + constexpr int _head_size = num_rO_tiles*REG_N/num_P_tiles; + auto P2 = P.format(); + #pragma unroll + for(int k = 0, ri=0; k < _head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K*k*sizeof(half), Vmat.format()); + + //# compensate cur_O + // matrix rO; + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); + #pragma unroll + for(int r = 0; r < REG_M; r++) + cO.row(r) = cm_mul(cO.row(r), max_comp[r + p*REG_M]); + } + + //show(rO[ri].format()); + + //# show(cur_O.format()); return; + #pragma unroll + for(int p = 0; p < num_P_tiles; p++) { + rO[ri + p] = cm_dpas( + rO[ri + p].format(), + Vmat.format(), + P2.row(p).format()); + //if (kv_pos == args_verbose) show(rO[ri + p].format()); + } + // if (kv_pos == args_verbose) show(cur_O.format()); + } +} + +template +inline matrix mma_KQ(uint slm_K, matrix_ref Qt, uint slm_offset = 0) { + matrix St; + constexpr int num_K = _kv_step / REG_M; +// auto St2 = St.format(); 
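+    // DPAS-free fallback: accumulate St = K @ Qt with explicit per-depth FMAs over the
+    // SystolicDepth x VNNI operand layout, producing the same result cm_dpas gives in ugemm_KQ.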
+ auto St2 = St.format(); + St2 = 0; + matrix Kmat; + cm_slm_block_read(slm_K, GENX_NONE, slm_offset, Kmat.format()); + vector QmatFp32; + vector KmatFp32; +#pragma unroll + for (int k = 0; k < num_K; k++) { + KmatFp32 = Kmat.row(k); + QmatFp32 = Qt.row(0); +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + St2.select(k * REG_M * REG_N + nCol * REG_N) += + QmatFp32.select(depth * REG_N * 2 + 0) * KmatFp32[nCol * REG_K + 2 * depth]; + } + } + +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + St2.select(k * REG_M * REG_N + nCol * REG_N) += + QmatFp32.select(depth * REG_N * 2 + 1) * KmatFp32[nCol * REG_K + 2 * depth + 1]; + } + } + } + +#pragma unroll + for (int ri = 1; ri < num_Qt; ri++) { + QmatFp32 = Qt.row(ri); + cm_slm_block_read(slm_K, GENX_NONE, slm_offset + ri * Kmat.n_elems() * sizeof(half), Kmat.format()); +#pragma unroll + for (int k = 0; k < num_K; k++) { + KmatFp32 = Kmat.row(k); +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + St2.select(k * REG_M * REG_N + nCol * REG_N) += + QmatFp32.select(depth * REG_N * 2 + 0) * KmatFp32[nCol * REG_K + 2 * depth]; + } + } + +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + St2.select(k * REG_M * REG_N + nCol * REG_N) += + QmatFp32.select(depth * REG_N * 2 + 1) * KmatFp32[nCol * REG_K + 2 * depth + 1]; + } + } + } + } + return St; +} + +template +vector online_softmax_update(matrix_ref St, vector_ref cur_max, vector_ref cur_sum) { + vector new_max_t; + new_max_t = cm_max(St[0], St[1]); + for(int r = 2; r < St.n_rows(); r++) new_max_t = cm_max(new_max_t, St[r]); + new_max_t = cm_max(new_max_t, cur_max); + + // Pt = torch.exp(St - new_max) + constexpr float log2e = 1.4426950408889634f; + for(int r = 0; r < St.n_rows(); r++) St[r] = cm_exp((St[r] - new_max_t)*log2e); + + vector row_sum_t; + row_sum_t = cm_add(St[0], St[1]); + for(int r = 2; r < St.n_rows(); r++) row_sum_t = cm_add(row_sum_t, St[r]); + + vector max_comp; + max_comp = cm_exp((cur_max - new_max_t)*log2e); + cur_sum = cm_mul(cur_sum, max_comp); + cur_sum = cm_add(cur_sum, row_sum_t); + cur_max = new_max_t; + return max_comp; +} + +//=============================================================================================== +template +constexpr void apply_causal_mask(matrix_ref St) { + if constexpr (i < N) { + St.row(i).select(0) = -3.4e38f; + apply_causal_mask(St); + } +} + +#ifdef CM_HAS_LSC_UNTYPED_2D template void sdpa_kernel_lsc( uint slm_K, @@ -425,6 +545,7 @@ void sdpa_kernel_lsc( } } + template void sdpa_kernel_lsc_prefetch( int wg_local_id, @@ -604,8 +725,7 @@ void sdpa_kernel_lsc_prefetch( cm_store(b2dO.set_block_y(REG_M), cur_O_f16.format().row(1)); } } - -#else // CM_HAS_LSC_UNTYPED_2D +#endif template void sdpa_kernel( @@ -809,4 +929,278 @@ void sdpa_kernel( } } -#endif // !CM_HAS_LSC_UNTYPED_2D \ No newline at end of file +template +void CM_INLINE sdpa_kernel_mma( + uint slm_K, + uint slm_V, + int wg_local_id, + int local_size, + int q_start, + int kv_stop, + int q_len, + int kv_len, + SurfaceIndex query [[type("buffer_t")]], + SurfaceIndex key [[type("buffer_t")]], + SurfaceIndex value [[type("buffer_t")]], + SurfaceIndex output [[type("buffer_t")]], + uint q_off, + uint k_off, + uint v_off, + uint 
o_off) { + + constexpr uint o_pitch = (num_heads * head_size * sizeof(half)); + constexpr uint q_pitch = is_qkv_fused ? ((num_heads + num_kv_heads * 2) * head_size * sizeof(half)) : o_pitch; + constexpr uint kv_pitch = is_qkv_fused ? q_pitch : (num_kv_heads * head_size * sizeof(half)); + + vector cur_max; + vector cur_sum; + + cur_max = -3e38f; + cur_sum = 0; + + matrix rQ; + auto q_tokens_left = q_len; + static_assert(q_step == REG_N); + static_assert(kv_step == REG_K); + + if (q_tokens_left < 0) q_tokens_left = 0; + if (q_tokens_left > q_step) q_tokens_left = q_step; + + if (q_tokens_left > 0) { + // load as many as possible given one address + if constexpr (head_size == 128 || head_size == 64) { + matrix QmatI32; + cm_load_2d(QmatI32, query, q_off, q_pitch); +#pragma unroll + for (int k = 0, ri = 0; k < head_size / 2; k += REG_K / 2, ri++) { + Transpose2DMatrix(QmatI32.select(0, k), rQ[ri].format()); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } + else { +#pragma unroll + for (int k = 0, ri = 0; k < head_size / 2; k += REG_K / 2, ri++) { + matrix QmatI32; + cm_load_2d(QmatI32, query, q_off + k * sizeof(uint), q_pitch); + Transpose2DMatrix(QmatI32, rQ[ri].format()); + rQ[ri].format() = cm_mul(rQ[ri].format(), (half)scale_factor); + } + } + } + + constexpr int num_P_tiles = REG_N / REG_M; + matrix rO; + int causal_left = q_start; + rO = 0.0f; + + constexpr uint slm_buff_size = kv_step * head_size * sizeof(half); + int slm_buff_id_write = 0; + int slm_buff_id_read = 0; + + auto load_slm_KV = [&](int kv_pos) { + //if (kv_pos < 1024000) return; + int kv_tokens = kv_stop - kv_pos; + if (kv_tokens <= 0) return; + uint slm_offset = (slm_buff_id_write & 3) * slm_buff_size; + slm_buff_id_write++; + + // non-tail branch is faster + if (wg_local_id < local_size / 2) { + //if (kv_pos > 1024000) { + matrix temp; + for (int k = REG_K * wg_local_id; k < head_size; k += REG_K * (local_size / 2)) { + cm_load_2d(temp, key, k_off + k * sizeof(half), kv_pitch); + cm_slm_block_write(slm_K, + slm_offset + k * 2 * REG_M * sizeof(half), + temp.format()); + } + } + else { + //if (kv_pos > 1024000) { + // read 16x16 XMX-B matrix (1x REG_N in Xe2, 2x REG_N in Xe1) + constexpr int VK_STEP = 16; + static_assert((VK_STEP % REG_N) == 0); + matrix temp2; + matrix temp_vnni; + //b2dV.set_block_y(kv_pos); + + static_assert((head_size % VK_STEP) == 0); +#pragma unroll + for (int k = VK_STEP * (wg_local_id - local_size / 2); k < head_size; k += VK_STEP * (local_size / 2)) { + cm_load_2d(temp2, value, v_off + k * sizeof(half), kv_pitch); + +#pragma unroll + for (int p = 0; p < VK_STEP / REG_N; p++) { + temp_vnni.select(0, 0) = temp2.select(0, p * REG_N); + temp_vnni.select(0, 1) = temp2.select(1, p * REG_N); + // show(temp_vnni); + cm_slm_block_write(slm_V, slm_offset + (k + p * REG_N) * REG_K * sizeof(half), temp_vnni.format()); + } + } + } + k_off += kv_step * kv_pitch; + v_off += kv_step * kv_pitch; + // printf(" diff= %lu\n", get_clock() - clk0); + }; + + load_slm_KV(0); + load_slm_KV(kv_step); + + cm_slm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(1); + + for (int kv_pos = 0; kv_pos < kv_stop; kv_pos += kv_step, + slm_buff_id_read++) { + // + // load0->0, signal1, + // [load1->1, wait2, signal2, read0] + // [load2->2, wait3, signal3, read1] + // [load3->3, wait4, signal4, read2] + // [load4->0, wait5, signal5, read3] + // + // after wait4, all workers have reached signal3, so: + // - all workers have finished load2 & read0. 
+ // - we can start to load 4 into SLM slot 0 (i & 3) safely + // - we can start to read 2 ((i-2) & 3) safely + // + cm_fence(CM_LOCAL_BARRIER); + cm_sbarrier(0); + + load_slm_KV(kv_pos + 2 * kv_step); + + if (kv_pos + kv_step < kv_stop) + cm_sbarrier(1); + + //if (kv_pos < 1024000) continue; + uint slm_offset = (slm_buff_id_read & 3) * slm_buff_size; + + //=========================================================== 1807 ~ 3247 + //# St = k @ Qt + matrix St = mma_KQ(slm_K, rQ, slm_offset); + + if constexpr (use_causal_mask) { + if (causal_left < kv_step) { + vector cmask = 0.0f; + int p = causal_left + 1; + int v = 0; + for (; p < 0; p++) { + cmask[v] = -3.4e38f; + if (v < q_step - 1) v++; + } + for (; p < kv_step; p++) { + cmask[v] = -3.4e38f; + St[p] = cm_add(St[p], cmask); + if (v < q_step - 1) v++; + } + //if (wg_local_id == 0) show(St);return; + } + causal_left -= kv_step; + } + + // mask off k-tails + int kv_tokens = kv_stop - kv_pos; + for (int p = kv_tokens; p < kv_step; p++) St[p] = -3.4e38f; + + //show(St); + auto max_comp = online_softmax_update(St, cur_max, cur_sum); + + matrix P; + Transpose2DMatrix(St, P); + vector P3 = P.format(); + if (kv_pos == 0) { +// auto P3 = P.format(); + vector VmatFp32; +#pragma unroll + for (int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K * k * sizeof(half), Vmat.format()); + VmatFp32 = Vmat; +#pragma unroll + for (int p = 0; p < num_P_tiles; p++) { +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + rO.row(ri + p).select(nCol * REG_N) += + VmatFp32.select(depth * REG_N * 2 + 0) * P3[p * REG_M * REG_K + nCol * REG_K + 2 * depth]; + } + } + +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + rO.row(ri + p).select(nCol * REG_N) += + VmatFp32.select(depth * REG_N * 2 + 1) * P3[p * REG_M * REG_K + nCol * REG_K + 2 * depth + 1]; + } + } + //show(rO[ri + p].format()); + } + } + } + else { + vector VmatFp32; +#pragma unroll + for (int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { + matrix Vmat; + cm_slm_block_read(slm_V, GENX_NONE, slm_offset + REG_K * k * sizeof(half), Vmat.format()); + VmatFp32 = Vmat; + //# compensate cur_O + // matrix rO; +#pragma unroll + for (int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); +#pragma unroll + for (int r = 0; r < REG_M; r++) + cO.row(r) = cm_mul(cO.row(r), max_comp[r + p * REG_M]); + } + + //show(rO[ri].format()); + + //# show(cur_O.format()); return; +#pragma unroll + for (int p = 0; p < num_P_tiles; p++) { +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + rO.row(ri + p).select(nCol * REG_N) += + VmatFp32.select(depth * REG_N * 2 + 0) * P3[p * REG_M * REG_K + nCol * REG_K + 2 * depth]; + } + } + +#pragma unroll + for (int32_t depth = 0; depth < SystolicDepth; depth++) { +#pragma unroll + for (int32_t nCol = 0; nCol < REG_M; nCol++) { + rO.row(ri + p).select(nCol * REG_N) += + VmatFp32.select(depth * REG_N * 2 + 1) * P3[p * REG_M * REG_K + nCol * REG_K + 2 * depth + 1]; + } + } + //if (kv_pos == args_verbose) show(rO[ri + p].format()); + } + // if (kv_pos == args_verbose) show(cur_O.format()); + } + } + } + + if (q_tokens_left > 0) { + //# save cur_O/cur_sum.transpose(0, 1) + matrix cur_O_f16; + cur_sum = cm_inv(cur_sum); + 
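+        // Rescale each accumulated output row by 1/row_sum (cur_sum holds the reciprocals
+        // after cm_inv) and store the fp16 result to the output buffer tile by tile.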
+#pragma unroll + for (int k = 0, ri = 0; k < head_size; k += REG_N, ri += num_P_tiles) { +#pragma unroll + for (int p = 0; p < num_P_tiles; p++) { + auto cO = rO[ri + p].format(); +#pragma unroll + for (int r = 0; r < cO.n_rows(); r++) { + cur_O_f16[r + p * REG_M] = cm_mul(cO.row(r), cur_sum[r + p * REG_M]); + } + } + // if (i == args_verbose) show(cur_O_f16); + cm_store_2d(cur_O_f16, output, o_off + k * sizeof(half), o_pitch); + } + } +} diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.cpp b/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.cpp index 8127c86fa0c771..ab231ff2664485 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.cpp @@ -56,6 +56,7 @@ class VLSDPAGenerator : public KernelGenerator { const size_t num_q_heads = query_shape[query_shape.size() - 3].get_length(); const size_t num_kv_heads = key_shape[key_shape.size() - 3].get_length(); const float scale_factor = 1.0 / std::sqrt(static_cast(head_size)); + const bool use_stub_kernel = !params.get_device_info().supports_immad; GPU_DEBUG_TRACE_DETAIL << "VLSDPA query_shape " << query_shape << ", q_transpose_order " << PartialShape(desc->input_q_transpose_order) << ", key_shape " << key_shape << ", k_transpose_order " << PartialShape(desc->input_k_transpose_order) @@ -67,6 +68,7 @@ class VLSDPAGenerator : public KernelGenerator { make_jit_constant("CMFLA_NUM_KV_HEADS", num_kv_heads), make_jit_constant("CMFLA_HEAD_SIZE", head_size), make_jit_constant("CMFLA_SCALE_FACTOR", scale_factor), + make_jit_constant("CMFLA_USE_STUB_KERNEL", use_stub_kernel ? 1 : 0), }); return jit; diff --git a/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.hpp b/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.hpp index e6f7aaea039f95..999cda11558672 100644 --- a/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/cm/vl_sdpa_opt.hpp @@ -42,10 +42,9 @@ struct VLSDPAOptImplementationManager : public ImplementationManager { assert(node.is_type()); auto& engine = node.get_program().get_engine(); const auto& config = node.get_program().get_config(); - const auto& info = engine.get_device_info(); // CM optimized for systolic-array architectures - if (!check_cm_jit_support(engine, config) || !info.supports_immad || !config.get_use_cm()) { + if (!check_cm_jit_support(engine, config) || !config.get_use_cm()) { return false; } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/kernel_selector_helper.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/kernel_selector_helper.cpp index 6e70ba0cb12cb1..e18eab4c615e2a 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/kernel_selector_helper.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/kernel_selector_helper.cpp @@ -87,7 +87,6 @@ bool check_cm_jit_support(cldnn::engine& e, const cldnn::ExecutionConfig& config // This program checks if cm sources can be jitted by current IGC version const char* kernel_code = R""""( static_assert(__cplusplus >= 201703L); - static_assert(CM_HAS_DPAS); CM_INLINE uint64_t dummy() { return ((uint64_t)0L); } diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index c5868e715b9a4e..2af66fa1f6c044 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -367,10 +367,6 @@ void TransformationsPipeline::apply(std::shared_ptr func) { pass_config->set_callback( 
[&](const_node_ptr &) -> bool { auto& engine = m_context->get_engine(); - const auto& info = engine.get_device_info(); - if (!(info.supports_immad)) { // CM optimized for systolic-array architectures - return true; - } #ifdef GPU_DEBUG_CONFIG if (!config.get_use_cm()) { diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/vlsdpa_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/vlsdpa_gpu_test.cpp index 01c73f19619a43..bbeb4ba9e93d53 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/vlsdpa_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/vlsdpa_gpu_test.cpp @@ -197,7 +197,7 @@ struct vlsdpa_gpu_test : public ::testing::TestWithParam { static bool check_vlsdpa_available() { auto& engine = get_test_engine(); ExecutionConfig config = get_test_default_config(engine); - if (!cldnn::check_cm_jit_support(engine, config) || !engine.get_device_info().supports_immad) { + if (!cldnn::check_cm_jit_support(engine, config)) { return false; }