Skip to content

Commit 81bad78

Browse files
snonkavik-pal
authored andcommitted
dot general case
1 parent 88ee975 commit 81bad78

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,39 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710710
}
711711
}
712712

713+
// A(A^T) will always be symmetric
714+
if (auto dotOp = dyn_cast_or_null<stablehlo::DotGeneralOp>(op)) {
715+
auto lhs = dotOp.getLhs();
716+
auto rhs = dotOp.getRhs();
717+
718+
auto lhsDims = dotOp.getDotDimensionNumbers().getLhsContractingDimensions();
719+
auto rhsDims = dotOp.getDotDimensionNumbers().getRhsContractingDimensions();
720+
721+
if (auto lhsType = dyn_cast_or_null<ShapedType>(lhs.getType());
722+
lhsType && lhsType.hasRank() && lhsType.getRank() == 2) {
723+
724+
if (rhs == lhs) {
725+
if (lhsDims.size() == 1 && lhsDims == rhsDims) {
726+
return State::GUARANTEED;
727+
}
728+
}
729+
730+
if (lhsDims.size() == 1 && lhsDims != rhsDims) {
731+
if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
732+
auto rhsInput = rhsT.getOperand();
733+
if (rhsInput == lhs && isTrueTranspose(rhsT))
734+
return State::GUARANTEED;
735+
}
736+
737+
if (auto lhsT = dyn_cast_or_null<stablehlo::TransposeOp>(lhs.getDefiningOp())) {
738+
auto lhsInput = lhsT.getOperand();
739+
if (lhsInput == rhs && isTrueTranspose(lhsT))
740+
return State::GUARANTEED;
741+
}
742+
}
743+
}
744+
}
745+
713746
bool recursiveCheck = false;
714747

715748
// elementwise ops

0 commit comments

Comments
 (0)