Skip to content

Commit 4dca341

Browse files
authored
fix: restrict raising dot_general to syrk to 2D tensors (#1666)
1 parent c8ab3b7 commit 4dca341

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25884,6 +25884,14 @@ struct DotGeneralToSyrk
2588425884
auto lhs = op.getLhs();
2588525885
auto rhs = op.getRhs();
2588625886

25887+
auto lhsType = cast<RankedTensorType>(lhs.getType());
25888+
auto rhsType = cast<RankedTensorType>(rhs.getType());
25889+
auto outType = cast<RankedTensorType>(op.getResult().getType());
25890+
if (lhsType.getRank() != 2 || rhsType.getRank() != 2 ||
25891+
outType.getRank() != 2) {
25892+
return failure();
25893+
}
25894+
2588725895
if (dotDims.getLhsBatchingDimensions().size() != 0 ||
2588825896
dotDims.getRhsBatchingDimensions().size() != 0) {
2588925897
return failure();

test/lit_tests/dotgeneral_to_syrk.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,15 @@ func.func @main4(%arg0: tensor<64x32xf32>) -> tensor<64x64xf32> {
7171
// CHECK-NEXT: %0 = enzymexla.blas.syrk %arg0, %cst_1, %cst_0, %cst {fill, uplo = #enzymexla.uplo<F>} : (tensor<64x32xf32>, tensor<64x64xf32>, tensor<f32>, tensor<f32>) -> tensor<64x64xf32>
7272
// CHECK-NEXT: return %0 : tensor<64x64xf32>
7373
// CHECK-NEXT: }
74+
75+
// Negative case for the dot_general -> syrk raising: %arg0 is reshaped to a
// rank-1 tensor and contracted with itself along dim 0 into a rank-0 result.
// The CHECK lines that follow expect the dot_general to be left unchanged.
func.func @fail1(%arg0: tensor<5x2xf32>) -> tensor<f32> {
76+
%0 = stablehlo.reshape %arg0 : (tensor<5x2xf32>) -> tensor<10xf32>
77+
%1 = stablehlo.dot_general %0, %0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10xf32>, tensor<10xf32>) -> tensor<f32>
78+
return %1 : tensor<f32>
79+
}
80+
81+
// CHECK: func.func @fail1(%arg0: tensor<5x2xf32>) -> tensor<f32> {
82+
// CHECK-NEXT: %0 = stablehlo.reshape %arg0 : (tensor<5x2xf32>) -> tensor<10xf32>
83+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<10xf32>, tensor<10xf32>) -> tensor<f32>
84+
// CHECK-NEXT: return %1 : tensor<f32>
85+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)