@@ -672,8 +672,6 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
672672 return perm.size () == 2 && perm[0 ] == 1 && perm[1 ] == 0 ;
673673 };
674674
675- // TODO: check for dot_general as well
676-
677675 if (auto broadcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(op)) {
678676 auto operand = broadcastOp.getOperand ();
679677 auto operandTy = cast<RankedTensorType>(operand.getType ());
@@ -710,34 +708,33 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
710708 }
711709 }
712710
713- // A(A^T) will always be symmetric
714- if (auto dotOp = dyn_cast_or_null<stablehlo::DotGeneralOp>(op)) {
711+ // A x (A^T) / (A^T) x A will always be symmetric
712+ if (auto dotOp = dyn_cast<stablehlo::DotGeneralOp>(op)) {
713+ auto dotDimNumbers = dotOp.getDotDimensionNumbers ();
715714 auto lhs = dotOp.getLhs ();
716715 auto rhs = dotOp.getRhs ();
717716
718- auto lhsDims = dotOp. getDotDimensionNumbers () .getLhsContractingDimensions ();
719- auto rhsDims = dotOp. getDotDimensionNumbers () .getRhsContractingDimensions ();
717+ auto lhsCDims = dotDimNumbers .getLhsContractingDimensions ();
718+ auto rhsCDims = dotDimNumbers .getRhsContractingDimensions ();
720719
721- if (auto lhsType = dyn_cast_or_null<ShapedType>(lhs.getType ());
722- lhsType && lhsType.hasRank () && lhsType.getRank () == 2 ) {
720+ if (dotDimNumbers.getLhsBatchingDimensions ().size () == 0 &&
721+ dotDimNumbers.getRhsBatchingDimensions ().size () == 0 &&
722+ lhsCDims.size () == 1 && rhsCDims.size () == 1 ) {
723+ if (lhs == rhs && lhsCDims[0 ] == rhsCDims[0 ]) {
724+ return State::GUARANTEED;
725+ }
723726
724- if (rhs == lhs) {
725- if (lhsDims.size () == 1 && lhsDims == rhsDims) {
727+ if (auto lhsT = lhs.getDefiningOp <stablehlo::TransposeOp>()) {
728+ if (isTrueTranspose (lhsT) && lhsT.getOperand () == rhs &&
729+ lhsCDims[0 ] == 1 - rhsCDims[0 ]) {
726730 return State::GUARANTEED;
727731 }
728732 }
729733
730- if (lhsDims.size () == 1 && lhsDims != rhsDims) {
731- if (auto rhsT = rhs.getDefiningOp <stablehlo::TransposeOp>()) {
732- auto rhsInput = rhsT.getOperand ();
733- if (rhsInput == lhs && isTrueTranspose (rhsT))
734- return State::GUARANTEED;
735- }
736-
737- if (auto lhsT = lhs.getDefiningOp <stablehlo::TransposeOp>()) {
738- auto lhsInput = lhsT.getOperand ();
739- if (lhsInput == rhs && isTrueTranspose (lhsT))
740- return State::GUARANTEED;
734+ if (auto rhsT = rhs.getDefiningOp <stablehlo::TransposeOp>()) {
735+ if (isTrueTranspose (rhsT) && rhsT.getOperand () == lhs &&
736+ lhsCDims[0 ] == 1 - rhsCDims[0 ]) {
737+ return State::GUARANTEED;
741738 }
742739 }
743740 }
0 commit comments