@@ -672,8 +672,6 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
672672 return perm.size () == 2 && perm[0 ] == 1 && perm[1 ] == 0 ;
673673 };
674674
675- // TODO: check for dot_general as well
676-
677675 if (auto broadcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(op)) {
678676 auto operand = broadcastOp.getOperand ();
679677 auto operandTy = cast<RankedTensorType>(operand.getType ());
@@ -710,34 +708,35 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710708 }
711709 }
712710
713- // A(A^T) will always be symmetric
714- if (auto dotOp = dyn_cast_or_null<stablehlo::DotGeneralOp>(op)) {
711+ // A x (A^T) / (A^T) x A will always be symmetric
712+ if (auto dotOp = dyn_cast<stablehlo::DotGeneralOp>(op)) {
713+ auto dotDimNumbers = dotOp.getDotDimensionNumbers ();
715714 auto lhs = dotOp.getLhs ();
716715 auto rhs = dotOp.getRhs ();
717716
718- auto lhsDims = dotOp.getDotDimensionNumbers ().getLhsContractingDimensions ();
719- auto rhsDims = dotOp.getDotDimensionNumbers ().getRhsContractingDimensions ();
720-
721- if (auto lhsType = dyn_cast_or_null<ShapedType>(lhs.getType ());
722- lhsType && lhsType.hasRank () && lhsType.getRank () == 2 ) {
717+ auto lhsCDims = dotDimNumbers.getLhsContractingDimensions ();
718+ auto rhsCDims = dotDimNumbers.getRhsContractingDimensions ();
723719
724- if (rhs == lhs) {
725- if (lhsDims.size () == 1 && lhsDims == rhsDims) {
720+ if (dotDimNumbers.getLhsBatchingDimensions ().size () == 0 &&
721+ dotDimNumbers.getRhsBatchingDimensions ().size () == 0 &&
722+ lhsCDims.size () == 1 && rhsCDims.size () == 1 ) {
723+ if (lhs == rhs) {
724+ if (lhsCDims[0 ] == rhsCDims[0 ]) {
726725 return State::GUARANTEED;
727726 }
728727 }
729728
730- if (lhsDims.size () == 1 && lhsDims != rhsDims) {
731- if (auto rhsT = rhs.getDefiningOp <stablehlo::TransposeOp>()) {
732- auto rhsInput = rhsT.getOperand ();
733- if (rhsInput == lhs && isTrueTranspose (rhsT))
734- return State::GUARANTEED;
729+ if (auto lhsT = lhs.getDefiningOp <stablehlo::TransposeOp>()) {
730+ if (isTrueTranspose (lhsT) && lhsT.getOperand () == rhs &&
731+ lhsCDims[0 ] == 1 - rhsCDims[0 ]) {
732+ return State::GUARANTEED;
735733 }
734+ }
736735
737- if (auto lhsT = lhs .getDefiningOp <stablehlo::TransposeOp>()) {
738- auto lhsInput = lhsT .getOperand ();
739- if (lhsInput == rhs && isTrueTranspose (lhsT))
740- return State::GUARANTEED;
736+ if (auto rhsT = rhs .getDefiningOp <stablehlo::TransposeOp>()) {
737+ if ( isTrueTranspose (rhsT) && rhsT .getOperand () == lhs &&
738+ lhsCDims[ 0 ] == 1 - rhsCDims[ 0 ]) {
739+ return State::GUARANTEED;
741740 }
742741 }
743742 }
0 commit comments