2 files changed: +6 -6 lines changed.
Columns: original file line number | diff line number | diff line change
@@ -642,8 +642,8 @@ static __global__ void flash_attn_stream_k_fixup(
642642 const int iter_k = (ne11 + (nbatch_fa - 1 )) / nbatch_fa;
643643 const int iter_j = (ne01 + (ncols1 - 1 )) / ncols1;
644644
645- const int kbc0 = (bidx0 + 0 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
646- const int kbc0_stop = (bidx0 + 1 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
645+ const int kbc0 = int64_t (bidx0 + 0 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
646+ const int kbc0_stop = int64_t (bidx0 + 1 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
647647
648648 const bool did_not_have_any_data = kbc0 == kbc0_stop;
649649 const bool wrote_beginning_of_tile = kbc0 % iter_k == 0 ;
@@ -679,7 +679,7 @@ static __global__ void flash_attn_stream_k_fixup(
679679 int bidx = bidx0 - 1 ;
680680 int kbc_stop = kbc0;
681681 while (true ) {
682- const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
682+ const int kbc = int64_t ( bidx) *(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
683683 if (kbc == kbc_stop) { // Did not have any data.
684684 bidx--;
685685 kbc_stop = kbc;
Columns: original file line number | diff line number | diff line change
@@ -1380,8 +1380,8 @@ static __global__ void flash_attn_ext_f16(
13801380 const int iter_j = (ne01.z + (ncols1 - 1 )) / ncols1;
13811381
13821382 // kbc == k block continuous, current index in continuous ijk space.
1383- int kbc = (blockIdx .x + 0 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
1384- const int kbc_stop = (blockIdx .x + 1 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
1383+ int kbc = int64_t (blockIdx .x + 0 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
1384+ const int kbc_stop = int64_t (blockIdx .x + 1 )*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim .x ;
13851385
13861386 // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
13871387 // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1401,7 +1401,7 @@ static __global__ void flash_attn_ext_f16(
14011401 const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
14021402 const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
14031403 const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1404- (const half *) (mask + nb33*(sequence % ne33));
1404+ (const half *) (mask + nb33*(sequence % ne33));
14051405 float2 * dstk = ((float2 *) dst) + (sequence*ne01.z *ne02 + head0) * (DV/2 );
14061406
14071407 const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
You can’t perform that action at this time.
0 commit comments