diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
index 7b903f4fc9..2fe2ee3eeb 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/kernel/xe_sdpa_fwd_bshd.h
@@ -191,7 +191,7 @@ class FMHAPrefill {
   }
 
   // Find the length of the longest non masked sequence within that subgroup
-  int calculate_longest_non_masked_length(
+  CUTLASS_DEVICE int calculate_longest_non_masked_length(
       const int& seq_len_kv,
       const int& seq_len_qo,
       const int& last_seq_coord,
@@ -222,7 +222,7 @@ class FMHAPrefill {
   }
 
   template
-  void handle_corner_cases(
+  CUTLASS_DEVICE void handle_corner_cases(
       Tensor& tSr,
       const int& thread_idx,
       const int& SubgroupSize,
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
index aa467b4ea9..5605292c77 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp
@@ -1461,18 +1461,17 @@ std::tuple flash_attention_backward_sycltla(
           .get_info<
               sycl::ext::oneapi::experimental::info::device::architecture>();
   constexpr auto supported_architectures =
-      std::array<sycl::ext::oneapi::experimental::architecture, 4>{
+      std::array<sycl::ext::oneapi::experimental::architecture, 3>{
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
+          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
   if (std::find(
           supported_architectures.begin(),
           supported_architectures.end(),
           device_architecture) == supported_architectures.end()) {
     TORCH_CHECK(
         false,
-        "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
+        "XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
   }
 
   auto grad_query = at::empty_like(query);
diff --git a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
index 2ac153ad99..cda198b8e8 100644
--- a/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
+++ b/src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
@@ -333,19 +333,19 @@ void run_mha_fwd_(
           TileShapeOutPut,
           SubgroupLayout,
           PipelineStages);
+    } else {
+      constexpr int PipelineStages = 2;
+      using TileShapeQK = Shape<_256, _32, _64>;
+      using TileShapePV = Shape<_256, _32, _32>;
+      using TileShapeOutPut = Shape<_256, _128, _32>;
+      using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
+      run_mha_fwd_specialized(
+          TileShapeQK,
+          TileShapePV,
+          TileShapeOutPut,
+          SubgroupLayout,
+          PipelineStages);
     }
-
-    constexpr int PipelineStages = 2;
-    using TileShapeQK = Shape<_256, _32, _64>;
-    using TileShapePV = Shape<_256, _32, _32>;
-    using TileShapeOutPut = Shape<_256, _128, _32>;
-    using SubgroupLayout = Layout, Stride<_1, _1, _1>>;
-    run_mha_fwd_specialized(
-        TileShapeQK,
-        TileShapePV,
-        TileShapeOutPut,
-        SubgroupLayout,
-        PipelineStages);
   } else if (headdim == 192) {
     constexpr int PipelineStages = 2;
     using TileShapeQK = Shape<_256, _64, _64>;
@@ -537,18 +537,17 @@ flash_attention_forward_sycltla(
           .get_info<
               sycl::ext::oneapi::experimental::info::device::architecture>();
   constexpr auto supported_architectures =
-      std::array<sycl::ext::oneapi::experimental::architecture, 4>{
+      std::array<sycl::ext::oneapi::experimental::architecture, 3>{
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
           sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
-          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
+          sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
   if (std::find(
           supported_architectures.begin(),
           supported_architectures.end(),
           device_architecture) == supported_architectures.end()) {
     TORCH_CHECK(
         false,
-        "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
+        "XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
   }
 
   auto problem_shape = ProblemShapeRegular(