@@ -191,7 +191,7 @@ class FMHAPrefill {
}

// Find the length of the longest non masked sequence within that subgroup
int calculate_longest_non_masked_length(
CUTLASS_DEVICE int calculate_longest_non_masked_length(
const int& seq_len_kv,
const int& seq_len_qo,
const int& last_seq_coord,
@@ -222,7 +222,7 @@ class FMHAPrefill {
}

template <class Tensor>
void handle_corner_cases(
CUTLASS_DEVICE void handle_corner_cases(
Tensor& tSr,
const int& thread_idx,
const int& SubgroupSize,
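The only functional change in this file is the `CUTLASS_DEVICE` qualifier added to the two helper member functions, the marker CUTLASS uses for functions invoked from kernel code. A minimal sketch of the pattern follows; the class name, parameters, and body are illustrative rather than the real `FMHAPrefill` implementation, and the macro's exact expansion depends on the compiler backend.

```cpp
#include <cutlass/cutlass.h>  // provides the CUTLASS_DEVICE function qualifier

// Illustrative only: a helper that is called from kernel code carries the
// CUTLASS_DEVICE qualifier; its exact expansion (device / force-inline
// attributes) is backend-dependent.
class FmhaHelperSketch {
 public:
  // Hypothetical helper mirroring calculate_longest_non_masked_length:
  // clamp a query row's causal horizon to the available KV length.
  CUTLASS_DEVICE int longest_non_masked_length(
      int seq_len_kv,
      int last_seq_coord) const {
    int horizon = last_seq_coord + 1;
    return horizon < seq_len_kv ? horizon : seq_len_kv;
  }
};
```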
@@ -1461,18 +1461,17 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> flash_attention_backward_sycltla(
.get_info<
sycl::ext::oneapi::experimental::info::device::architecture>();
constexpr auto supported_architectures =
std::array<sycl::ext::oneapi::experimental::architecture, 4>{
std::array<sycl::ext::oneapi::experimental::architecture, 3>{
sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
if (std::find(
supported_architectures.begin(),
supported_architectures.end(),
device_architecture) == supported_architectures.end()) {
TORCH_CHECK(
false,
"XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
"XPU device architecture does not support flash attention backward. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
}

auto grad_query = at::empty_like(query);
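Both the backward entry point above and the forward entry point below gate on the device architecture before building the kernel arguments: a `constexpr std::array` of supported architectures is searched with `std::find`, and `TORCH_CHECK` fails otherwise. This change drops `intel_gpu_bmg_g31` from that list, shrinking the array from 4 to 3 entries and updating the error message. Below is a condensed sketch of the gate as it reads after the change; the helper name `check_flash_attention_arch` is hypothetical (the source inlines the check in each entry point), and the lookup is folded into a single `TORCH_CHECK`.

```cpp
#include <algorithm>
#include <array>
#include <c10/util/Exception.h>  // TORCH_CHECK
#include <sycl/sycl.hpp>         // sycl::device and the architecture info query

// Hypothetical wrapper around the architecture gate shown in the diff above.
inline void check_flash_attention_arch(const sycl::device& dev) {
  namespace exp = sycl::ext::oneapi::experimental;
  const auto device_architecture =
      dev.get_info<exp::info::device::architecture>();
  // bmg_g31 is no longer accepted, so the array now holds three entries.
  constexpr auto supported_architectures = std::array<exp::architecture, 3>{
      exp::architecture::intel_gpu_pvc,
      exp::architecture::intel_gpu_pvc_vg,
      exp::architecture::intel_gpu_bmg_g21};
  TORCH_CHECK(
      std::find(
          supported_architectures.begin(),
          supported_architectures.end(),
          device_architecture) != supported_architectures.end(),
      "XPU device architecture does not support flash attention. Supported "
      "architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
}
```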
31 changes: 15 additions & 16 deletions src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_fwd.cpp
@@ -333,19 +333,19 @@ void run_mha_fwd_(
TileShapeOutPut,
SubgroupLayout,
PipelineStages);
} else {
constexpr int PipelineStages = 2;
using TileShapeQK = Shape<_256, _32, _64>;
using TileShapePV = Shape<_256, _32, _32>;
using TileShapeOutPut = Shape<_256, _128, _32>;
using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
run_mha_fwd_specialized(
TileShapeQK,
TileShapePV,
TileShapeOutPut,
SubgroupLayout,
PipelineStages);
}

constexpr int PipelineStages = 2;
using TileShapeQK = Shape<_256, _32, _64>;
using TileShapePV = Shape<_256, _32, _32>;
using TileShapeOutPut = Shape<_256, _128, _32>;
using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
run_mha_fwd_specialized(
TileShapeQK,
TileShapePV,
TileShapeOutPut,
SubgroupLayout,
PipelineStages);
} else if (headdim == 192) {
constexpr int PipelineStages = 2;
using TileShapeQK = Shape<_256, _64, _64>;
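The hunk above removes an inner `} else {` wrapper, so this branch of `run_mha_fwd_` now uses a single tile configuration unconditionally instead of choosing between two. The selection pattern itself is unchanged: each head-dimension branch fixes compile-time tile shapes, a subgroup layout, and a pipeline-stage count, then hands them to `run_mha_fwd_specialized`. A minimal sketch of that pattern follows; `launch_mha_fwd_sketch`, `run_mha_fwd_sketch`, and the `headdim == 128` guard are hypothetical stand-ins (the real guarding condition and the definition of `run_mha_fwd_specialized` are outside the visible hunk).

```cpp
#include <cute/layout.hpp>  // cute::Shape, cute::Layout, cute::Stride, _N constants

// Hypothetical stand-in for run_mha_fwd_specialized, written as a function
// template so the snippet is self-contained.
template <
    class TileShapeQK,
    class TileShapePV,
    class TileShapeOutPut,
    class SubgroupLayout,
    int PipelineStages>
void launch_mha_fwd_sketch() {
  // The real code instantiates and launches the specialized forward kernel here.
}

// Illustrative dispatch over the head dimension, mirroring the shapes in the
// hunk above.
void run_mha_fwd_sketch(int headdim) {
  using namespace cute;
  if (headdim == 128) {  // placeholder guard: the real condition is not shown above
    using TileShapeQK = Shape<_256, _32, _64>;
    using TileShapePV = Shape<_256, _32, _32>;
    using TileShapeOutPut = Shape<_256, _128, _32>;
    using SubgroupLayout = Layout<Shape<_16, _1, _1>, Stride<_1, _1, _1>>;
    launch_mha_fwd_sketch<
        TileShapeQK,
        TileShapePV,
        TileShapeOutPut,
        SubgroupLayout,
        /*PipelineStages=*/2>();
  } else if (headdim == 192) {
    // Same pattern with Shape<_256, _64, _64> for QK; the rest of this branch
    // sits outside the visible hunk.
  }
}
```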
@@ -537,18 +537,17 @@ flash_attention_forward_sycltla(
.get_info<
sycl::ext::oneapi::experimental::info::device::architecture>();
constexpr auto supported_architectures =
std::array<sycl::ext::oneapi::experimental::architecture, 4>{
std::array<sycl::ext::oneapi::experimental::architecture, 3>{
sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc,
sycl::ext::oneapi::experimental::architecture::intel_gpu_pvc_vg,
sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21,
sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g31};
sycl::ext::oneapi::experimental::architecture::intel_gpu_bmg_g21};
if (std::find(
supported_architectures.begin(),
supported_architectures.end(),
device_architecture) == supported_architectures.end()) {
TORCH_CHECK(
false,
"XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21, intel_gpu_bmg_g31.");
"XPU device architecture does not support flash attention. Supported architectures are: intel_gpu_pvc, intel_gpu_pvc_vg, intel_gpu_bmg_g21.");
}

auto problem_shape = ProblemShapeRegular(