Skip to content

Commit 4822114

Browse files
CUDA: fix overflow in MMA kernel without stream-k (#17939)
1 parent 7bed317 commit 4822114

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -642,8 +642,8 @@ static __global__ void flash_attn_stream_k_fixup(
642642
const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
643643
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
644644

645-
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
646-
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
645+
const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
646+
const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
647647

648648
const bool did_not_have_any_data = kbc0 == kbc0_stop;
649649
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -679,7 +679,7 @@ static __global__ void flash_attn_stream_k_fixup(
679679
int bidx = bidx0 - 1;
680680
int kbc_stop = kbc0;
681681
while(true) {
682-
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
682+
const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
683683
if (kbc == kbc_stop) { // Did not have any data.
684684
bidx--;
685685
kbc_stop = kbc;

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1380,8 +1380,8 @@ static __global__ void flash_attn_ext_f16(
13801380
const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
13811381

13821382
// kbc == k block continuous, current index in continuous ijk space.
1383-
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1384-
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1383+
int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1384+
const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
13851385

13861386
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
13871387
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1401,7 +1401,7 @@ static __global__ void flash_attn_ext_f16(
14011401
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
14021402
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
14031403
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1404-
(const half *) (mask + nb33*(sequence % ne33));
1404+
(const half *) (mask + nb33*(sequence % ne33));
14051405
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
14061406

14071407
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));

0 commit comments

Comments
 (0)