Skip to content

Commit 3028d0b

Browse files
committed
fix: update to new version
1 parent a8ceb9c commit 3028d0b

File tree

2 files changed

+22
-23
lines changed

2 files changed

+22
-23
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -672,8 +672,6 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
672672
return perm.size() == 2 && perm[0] == 1 && perm[1] == 0;
673673
};
674674

675-
// TODO: check for dot_general as well
676-
677675
if (auto broadcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(op)) {
678676
auto operand = broadcastOp.getOperand();
679677
auto operandTy = cast<RankedTensorType>(operand.getType());
@@ -710,34 +708,35 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710708
}
711709
}
712710

713-
// A(A^T) will always be symmetric
714-
if (auto dotOp = dyn_cast_or_null<stablehlo::DotGeneralOp>(op)) {
711+
// A x (A^T) / (A^T) x A will always be symmetric
712+
if (auto dotOp = dyn_cast<stablehlo::DotGeneralOp>(op)) {
713+
auto dotDimNumbers = dotOp.getDotDimensionNumbers();
715714
auto lhs = dotOp.getLhs();
716715
auto rhs = dotOp.getRhs();
717716

718-
auto lhsDims = dotOp.getDotDimensionNumbers().getLhsContractingDimensions();
719-
auto rhsDims = dotOp.getDotDimensionNumbers().getRhsContractingDimensions();
720-
721-
if (auto lhsType = dyn_cast_or_null<ShapedType>(lhs.getType());
722-
lhsType && lhsType.hasRank() && lhsType.getRank() == 2) {
717+
auto lhsCDims = dotDimNumbers.getLhsContractingDimensions();
718+
auto rhsCDims = dotDimNumbers.getRhsContractingDimensions();
723719

724-
if (rhs == lhs) {
725-
if (lhsDims.size() == 1 && lhsDims == rhsDims) {
720+
if (dotDimNumbers.getLhsBatchingDimensions().size() == 0 &&
721+
dotDimNumbers.getRhsBatchingDimensions().size() == 0 &&
722+
lhsCDims.size() == 1 && rhsCDims.size() == 1) {
723+
if (lhs == rhs) {
724+
if (lhsCDims[0] == rhsCDims[0]) {
726725
return State::GUARANTEED;
727726
}
728727
}
729728

730-
if (lhsDims.size() == 1 && lhsDims != rhsDims) {
731-
if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
732-
auto rhsInput = rhsT.getOperand();
733-
if (rhsInput == lhs && isTrueTranspose(rhsT))
734-
return State::GUARANTEED;
729+
if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
730+
if (isTrueTranspose(lhsT) && lhsT.getOperand() == rhs &&
731+
lhsCDims[0] == 1 - rhsCDims[0]) {
732+
return State::GUARANTEED;
735733
}
734+
}
736735

737-
if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
738-
auto lhsInput = lhsT.getOperand();
739-
if (lhsInput == rhs && isTrueTranspose(lhsT))
740-
return State::GUARANTEED;
736+
if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
737+
if (isTrueTranspose(rhsT) && rhsT.getOperand() == lhs &&
738+
lhsCDims[0] == 1 - rhsCDims[0]) {
739+
return State::GUARANTEED;
741740
}
742741
}
743742
}

test/lit_tests/structured_tensors/transpose_symmetric.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5555

5656
// CHECK: func.func @pass4(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
5757
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
58-
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
58+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
5959
// CHECK-NEXT: return %1 : tensor<2x2xf32>
6060
// CHECK-NEXT: }
6161

@@ -66,7 +66,7 @@ func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
6666
}
6767

6868
// CHECK: func.func @pass5(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
69-
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
69+
// CHECK-NEXT: %0 = stablehlo.dot_general %arg0, %arg0, contracting_dims = [0] x [0], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed GUARANTEED>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
7070
// CHECK-NEXT: return %0 : tensor<2x2xf32>
7171
// CHECK-NEXT: }
7272

@@ -93,7 +93,7 @@ func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
9393

9494
// CHECK: func.func @fail2(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> {
9595
// CHECK-NEXT: %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
96-
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] {enzymexla.guaranteed_symmetric = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
96+
// CHECK-NEXT: %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] {enzymexla.symmetric_matrix = [#enzymexla<guaranteed NOTGUARANTEED>]} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
9797
// CHECK-NEXT: %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf32>) -> tensor<2x2xf32>
9898
// CHECK-NEXT: return %2 : tensor<2x2xf32>
9999
// CHECK-NEXT: }

0 commit comments

Comments
 (0)