Skip to content

Commit a8ceb9c

Browse files
snonkavik-pal
authored andcommitted
add test and fmt
1 parent 81bad78 commit a8ceb9c

File tree

2 files changed

+42
-4
lines changed

2 files changed

+42
-4
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -710,14 +710,14 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710710
}
711711
}
712712

713-
// A(A^T) will always be symmetric
713+
// A(A^T) will always be symmetric
714714
if (auto dotOp = dyn_cast_or_null<stablehlo::DotGeneralOp>(op)) {
715715
auto lhs = dotOp.getLhs();
716716
auto rhs = dotOp.getRhs();
717717

718718
auto lhsDims = dotOp.getDotDimensionNumbers().getLhsContractingDimensions();
719719
auto rhsDims = dotOp.getDotDimensionNumbers().getRhsContractingDimensions();
720-
720+
721721
if (auto lhsType = dyn_cast_or_null<ShapedType>(lhs.getType());
722722
lhsType && lhsType.hasRank() && lhsType.getRank() == 2) {
723723

@@ -733,8 +733,8 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
733733
if (rhsInput == lhs && isTrueTranspose(rhsT))
734734
return State::GUARANTEED;
735735
}
736-
737-
if (auto lhsT = dyn_cast_or_null<stablehlo::TransposeOp>(lhs.getDefiningOp())) {
736+
737+
if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
738738
auto lhsInput = lhsT.getOperand();
739739
if (lhsInput == rhs && isTrueTranspose(lhsT))
740740
return State::GUARANTEED;

test/lit_tests/structured_tensors/transpose_symmetric.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,30 @@ func.func @pass3() -> tensor<3x3xf32> {
4646
// CHECK-NEXT: return %2 : tensor<3x3xf32>
4747
// CHECK-NEXT: }
4848

49+
func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
50+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
51+
%1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
52+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
53+
return %2 : tensor<2x2xf32>
54+
}
55+
56+
// CHECK: func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
57+
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
58+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
59+
// CHECK-NEXT: return %1 : tensor<2x2xf32>
60+
// CHECK-NEXT: }
61+
62+
func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
63+
%1 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
64+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
65+
return %2 : tensor<2x2xf32>
66+
}
67+
68+
// CHECK: func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
69+
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
70+
// CHECK-NEXT: return %0 : tensor<2x2xf32>
71+
// CHECK-NEXT: }
72+
4973
func.func @fail1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5074
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
5175
%1 = stablehlo.subtract %arg0, %0 : tensor<2x2xf32>
@@ -59,3 +83,17 @@ func.func @fail1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5983
// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
6084
// CHECK-NEXT: return %2 : tensor<2x2xf32>
6185
// CHECK-NEXT: }
86+
87+
func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
88+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
89+
%1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
90+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
91+
return %2 : tensor<2x2xf32>
92+
}
93+
94+
// CHECK: func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
95+
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
96+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
97+
// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
98+
// CHECK-NEXT: return %2 : tensor<2x2xf32>
99+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)