@@ -25871,6 +25871,227 @@ struct BinaryNegatedOperandsSimplify
2587125871 }
2587225872};
2587325873
25874+ // currently limited to non-batched dot_general
25875+ struct DotGeneralToSyrk
25876+ : public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
25877+ DotGeneralToSyrk> {
25878+ using CheckedOpRewritePattern<stablehlo::DotGeneralOp,
25879+ DotGeneralToSyrk>::CheckedOpRewritePattern;
25880+
25881+ LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
25882+ PatternRewriter &rewriter) const {
25883+ auto dotDims = op.getDotDimensionNumbers();
25884+ auto lhs = op.getLhs();
25885+ auto rhs = op.getRhs();
25886+
25887+ if (dotDims.getLhsBatchingDimensions().size() != 0 ||
25888+ dotDims.getRhsBatchingDimensions().size() != 0) {
25889+ return failure();
25890+ }
25891+
25892+ if (dotDims.getLhsContractingDimensions().size() != 1 ||
25893+ dotDims.getRhsContractingDimensions().size() != 1) {
25894+ return failure();
25895+ }
25896+
25897+ // check that transpose dimensions are [1,0]
25898+ auto isTrueTranspose = [](stablehlo::TransposeOp tOp) -> bool {
25899+ auto perm = tOp.getPermutation();
25900+ return perm.size() == 2 && perm[0] == 1 && perm[1] == 0;
25901+ };
25902+
25903+ auto lhsContractingDim = dotDims.getLhsContractingDimensions()[0];
25904+ auto rhsContractingDim = dotDims.getRhsContractingDimensions()[0];
25905+
25906+ Value syrkInput;
25907+ enzymexla::LapackTranspose lapackTranspose;
25908+
25909+ if (lhs == rhs && lhsContractingDim == rhsContractingDim) {
25910+ syrkInput = lhs;
25911+ if (lhsContractingDim == 1) {
25912+ lapackTranspose = enzymexla::LapackTranspose::none;
25913+ } else {
25914+ lapackTranspose = enzymexla::LapackTranspose::transpose;
25915+ }
25916+ }
25917+
25918+ if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
25919+ if (isTrueTranspose(lhsT) && lhsT.getOperand() == rhs &&
25920+ lhsContractingDim == 1 - rhsContractingDim) {
25921+ syrkInput = rhs;
25922+ if (rhsContractingDim == 1) {
25923+ lapackTranspose = enzymexla::LapackTranspose::none;
25924+ } else {
25925+ lapackTranspose = enzymexla::LapackTranspose::transpose;
25926+ }
25927+ }
25928+ }
25929+
25930+ if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
25931+ if (isTrueTranspose(rhsT) && rhsT.getOperand() == lhs &&
25932+ rhsContractingDim == 1 - lhsContractingDim) {
25933+ syrkInput = lhs;
25934+ if (lhsContractingDim == 0) {
25935+ lapackTranspose = enzymexla::LapackTranspose::transpose;
25936+ } else {
25937+ lapackTranspose = enzymexla::LapackTranspose::none;
25938+ }
25939+ }
25940+ }
25941+
25942+ if (!syrkInput)
25943+ return failure();
25944+
25945+ auto elemType =
25946+ cast<RankedTensorType>(syrkInput.getType()).getElementType();
25947+ auto alphaType = RankedTensorType::get({}, elemType);
25948+
25949+ auto syrkOp = enzymexla::SyrkOp::create(
25950+ rewriter, op.getLoc(), op.getResult().getType(), syrkInput,
25951+ stablehlo::ConstantOp::create(
25952+ rewriter, op.getLoc(), op.getType(),
25953+ cast<ElementsAttr>(makeAttr(op.getType(), 0))),
25954+ stablehlo::ConstantOp::create(
25955+ rewriter, op.getLoc(), alphaType,
25956+ cast<ElementsAttr>(makeAttr(alphaType, 1))),
25957+ stablehlo::ConstantOp::create(
25958+ rewriter, op.getLoc(), alphaType,
25959+ cast<ElementsAttr>(makeAttr(alphaType, 0))),
25960+ enzymexla::LapackUploAttr::get(op.getContext(),
25961+ enzymexla::LapackUplo::U),
25962+ enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
25963+ rewriter.getUnitAttr());
25964+ rewriter.replaceOp(op, syrkOp.getResult());
25965+ return success();
25966+ }
25967+ };
25968+
25969+ struct TransposeSyrkToSyrk
25970+ : public CheckedOpRewritePattern<enzymexla::SyrkOp, TransposeSyrkToSyrk> {
25971+ using CheckedOpRewritePattern<enzymexla::SyrkOp,
25972+ TransposeSyrkToSyrk>::CheckedOpRewritePattern;
25973+
25974+ LogicalResult matchAndRewriteImpl(enzymexla::SyrkOp op,
25975+ PatternRewriter &rewriter) const {
25976+ auto input = op.getA();
25977+ if (cast<RankedTensorType>(input.getType()).getRank() != 2)
25978+ return failure(); // support only rank 2 matrices for now
25979+
25980+ auto transposeOp = input.getDefiningOp<stablehlo::TransposeOp>();
25981+ if (!transposeOp)
25982+ return failure();
25983+
25984+ auto perm = transposeOp.getPermutation();
25985+ if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0)
25986+ return failure();
25987+
25988+ enzymexla::LapackTranspose lapackTranspose;
25989+ switch (op.getTranspose()) {
25990+ case enzymexla::LapackTranspose::none:
25991+ lapackTranspose = enzymexla::LapackTranspose::transpose;
25992+ break;
25993+ default:
25994+ lapackTranspose = enzymexla::LapackTranspose::none;
25995+ }
25996+
25997+ rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
25998+ op, op.getResult().getType(), transposeOp.getOperand(), op.getC(),
25999+ op.getAlpha(), op.getBeta(), op.getUploAttr(),
26000+ enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
26001+ op.getFillAttr());
26002+ return success();
26003+ }
26004+ };
26005+
26006+ struct FuseMulIntoSyrk
26007+ : public CheckedOpRewritePattern<stablehlo::MulOp, FuseMulIntoSyrk> {
26008+ using CheckedOpRewritePattern<stablehlo::MulOp,
26009+ FuseMulIntoSyrk>::CheckedOpRewritePattern;
26010+
26011+ LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
26012+ PatternRewriter &rewriter) const {
26013+ auto lhs = op.getLhs();
26014+ auto rhs = op.getRhs();
26015+
26016+ enzymexla::SyrkOp syrkOp;
26017+ Value other;
26018+
26019+ if (auto lhsSyrk = lhs.getDefiningOp<enzymexla::SyrkOp>()) {
26020+ syrkOp = lhsSyrk;
26021+ other = rhs;
26022+ } else if (auto rhsSyrk = rhs.getDefiningOp<enzymexla::SyrkOp>()) {
26023+ syrkOp = rhsSyrk;
26024+ other = lhs;
26025+ } else {
26026+ return failure();
26027+ }
26028+
26029+ Value scalarVal =
26030+ stablehlo::getScalarValue(other.getDefiningOp(), rewriter);
26031+ if (!scalarVal)
26032+ return failure();
26033+
26034+ auto newBeta = stablehlo::MulOp::create(rewriter, op.getLoc(),
26035+ syrkOp.getBeta(), scalarVal);
26036+ auto newAlpha = stablehlo::MulOp::create(rewriter, op.getLoc(),
26037+ syrkOp.getAlpha(), scalarVal);
26038+
26039+ rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
26040+ op, syrkOp.getType(), syrkOp.getA(), syrkOp.getC(), newAlpha, newBeta,
26041+ syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
26042+ return success();
26043+ }
26044+ };
26045+
26046+ struct FuseAddIntoSyrk
26047+ : public CheckedOpRewritePattern<stablehlo::AddOp,
26048+ FuseAddIntoSyrk>::CheckedOpRewritePattern {
26049+ using CheckedOpRewritePattern<stablehlo::AddOp,
26050+ FuseAddIntoSyrk>::CheckedOpRewritePattern;
26051+
26052+ LogicalResult matchAndRewriteImpl(stablehlo::AddOp op,
26053+ PatternRewriter &rewriter) const {
26054+ auto lhs = op.getLhs();
26055+ auto rhs = op.getRhs();
26056+
26057+ enzymexla::SyrkOp syrkOp;
26058+ Value other;
26059+
26060+ if (auto lhsSyrk = lhs.getDefiningOp<enzymexla::SyrkOp>()) {
26061+ syrkOp = lhsSyrk;
26062+ other = rhs;
26063+ } else if (auto rhsSyrk = rhs.getDefiningOp<enzymexla::SyrkOp>()) {
26064+ syrkOp = rhsSyrk;
26065+ other = lhs;
26066+ } else {
26067+ return failure();
26068+ }
26069+
26070+ // we can fuse this addition iff the other operand is a symmetric matrix
26071+ if (!canApplySymmetricPattern(other, rewriter))
26072+ return failure();
26073+
26074+ auto oldBeta = syrkOp.getBeta();
26075+ auto bcastedBeta = stablehlo::BroadcastInDimOp::create(
26076+ rewriter, op.getLoc(), syrkOp.getType(), oldBeta,
26077+ rewriter.getDenseI64ArrayAttr({}));
26078+
26079+ auto scaledC = stablehlo::MulOp::create(rewriter, op.getLoc(),
26080+ syrkOp.getC(), bcastedBeta);
26081+
26082+ auto newC = stablehlo::AddOp::create(rewriter, op.getLoc(), scaledC, other);
26083+
26084+ auto newBeta = stablehlo::ConstantOp::create(
26085+ rewriter, op.getLoc(), oldBeta.getType(),
26086+ cast<ElementsAttr>(makeAttr(oldBeta.getType(), 1)));
26087+
26088+ rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
26089+ op, syrkOp.getType(), syrkOp.getA(), newC, syrkOp.getAlpha(), newBeta,
26090+ syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
26091+ return success();
26092+ }
26093+ };
26094+
2587426095/////////////// End Imported from stablehlo
2587526096
2587626097// clang-format off
@@ -26423,6 +26644,12 @@ struct EnzymeHLOOptPass
2642326644 patterns.add<TransposeSymmetricSimplify>(context);
2642426645 patterns.add<FactorScalarsInDotGeneral>(context);
2642526646
26647+ // syrk patterns
26648+ // currently disabled since lowering is missing
26649+ // patterns.add<DotGeneralToSyrk>(context);
26650+ patterns.add<TransposeSyrkToSyrk, FuseMulIntoSyrk, FuseAddIntoSyrk>(
26651+ context);
26652+
2642626653 // clang-format off
2642726654 patterns.add<
2642826655 WhileRepeatedInductionReduction,
0 commit comments