File tree Expand file tree Collapse file tree 1 file changed +33
-0
lines changed
Expand file tree Collapse file tree 1 file changed +33
-0
lines changed Original file line number Diff line number Diff line change @@ -710,6 +710,39 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710710 }
711711 }
712712
713+ // A(A^T) will always be symmetric
714+ if (auto dotOp = dyn_cast_or_null<stablehlo::DotGeneralOp>(op)) {
715+ auto lhs = dotOp.getLhs ();
716+ auto rhs = dotOp.getRhs ();
717+
718+ auto lhsDims = dotOp.getDotDimensionNumbers ().getLhsContractingDimensions ();
719+ auto rhsDims = dotOp.getDotDimensionNumbers ().getRhsContractingDimensions ();
720+
721+ if (auto lhsType = dyn_cast_or_null<ShapedType>(lhs.getType ());
722+ lhsType && lhsType.hasRank () && lhsType.getRank () == 2 ) {
723+
724+ if (rhs == lhs) {
725+ if (lhsDims.size () == 1 && lhsDims == rhsDims) {
726+ return State::GUARANTEED;
727+ }
728+ }
729+
730+ if (lhsDims.size () == 1 && lhsDims != rhsDims) {
731+ if (auto rhsT = rhs.getDefiningOp <stablehlo::TransposeOp>()) {
732+ auto rhsInput = rhsT.getOperand ();
733+ if (rhsInput == lhs && isTrueTranspose (rhsT))
734+ return State::GUARANTEED;
735+ }
736+
737+ if (auto lhsT = dyn_cast_or_null<stablehlo::TransposeOp>(lhs.getDefiningOp ())) {
738+ auto lhsInput = lhsT.getOperand ();
739+ if (lhsInput == rhs && isTrueTranspose (lhsT))
740+ return State::GUARANTEED;
741+ }
742+ }
743+ }
744+ }
745+
713746 bool recursiveCheck = false ;
714747
715748 // elementwise ops
You can’t perform that action at this time.
0 commit comments