Skip to content

Commit a9caa01

Browse files
snonkavik-pal
andauthored
dot general case for symm (#1620)
* dot general case * add test and fmt * fix: update to new version --------- Co-authored-by: Avik Pal <[email protected]>
1 parent 4dca341 commit a9caa01

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,6 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
672672
return perm.size() == 2 && perm[0] == 1 && perm[1] == 0;
673673
};
674674

675-
// TODO: check for dot_general as well
676-
677675
if (auto broadcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(op)) {
678676
auto operand = broadcastOp.getOperand();
679677
auto operandTy = cast<RankedTensorType>(operand.getType());
@@ -710,6 +708,38 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710708
}
711709
}
712710

711+
// A x (A^T) / (A^T) x A will always be symmetric
712+
if (auto dotOp = dyn_cast<stablehlo::DotGeneralOp>(op)) {
713+
auto dotDimNumbers = dotOp.getDotDimensionNumbers();
714+
auto lhs = dotOp.getLhs();
715+
auto rhs = dotOp.getRhs();
716+
717+
auto lhsCDims = dotDimNumbers.getLhsContractingDimensions();
718+
auto rhsCDims = dotDimNumbers.getRhsContractingDimensions();
719+
720+
if (dotDimNumbers.getLhsBatchingDimensions().size() == 0 &&
721+
dotDimNumbers.getRhsBatchingDimensions().size() == 0 &&
722+
lhsCDims.size() == 1 && rhsCDims.size() == 1) {
723+
if (lhs == rhs && lhsCDims[0] == rhsCDims[0]) {
724+
return State::GUARANTEED;
725+
}
726+
727+
if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
728+
if (isTrueTranspose(lhsT) && lhsT.getOperand() == rhs &&
729+
lhsCDims[0] == 1 - rhsCDims[0]) {
730+
return State::GUARANTEED;
731+
}
732+
}
733+
734+
if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
735+
if (isTrueTranspose(rhsT) && rhsT.getOperand() == lhs &&
736+
lhsCDims[0] == 1 - rhsCDims[0]) {
737+
return State::GUARANTEED;
738+
}
739+
}
740+
}
741+
}
742+
713743
bool recursiveCheck = false;
714744

715745
// elementwise ops

test/lit_tests/structured_tensors/transpose_symmetric.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,30 @@ func.func @pass3() -> tensor<3x3xf32> {
4646
// CHECK-NEXT: return %2 : tensor<3x3xf32>
4747
// CHECK-NEXT: }
4848

49+
func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
50+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
51+
%1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
52+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
53+
return %2 : tensor<2x2xf32>
54+
}
55+
56+
// CHECK: func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
57+
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
58+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
59+
// CHECK-NEXT: return %1 : tensor<2x2xf32>
60+
// CHECK-NEXT: }
61+
62+
func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
63+
%1 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
64+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
65+
return %2 : tensor<2x2xf32>
66+
}
67+
68+
// CHECK: func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
69+
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
70+
// CHECK-NEXT: return %0 : tensor<2x2xf32>
71+
// CHECK-NEXT: }
72+
4973
func.func @fail1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5074
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
5175
%1 = stablehlo.subtract %arg0, %0 : tensor<2x2xf32>
@@ -59,3 +83,17 @@ func.func @fail1(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5983
// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
6084
// CHECK-NEXT: return %2 : tensor<2x2xf32>
6185
// CHECK-NEXT: }
86+
87+
func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
88+
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
89+
%1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
90+
%2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
91+
return %2 : tensor<2x2xf32>
92+
}
93+
94+
// CHECK: func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
95+
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
96+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
97+
// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
98+
// CHECK-NEXT: return %2 : tensor<2x2xf32>
99+
// CHECK-NEXT: }

0 commit comments

Comments
 (0)