Skip to content

Commit 88ee975

Browse files
authored
feat: common syrk optimization patterns (#1653)
* feat: common syrk optimization patterns
* feat: syrk output is always symmetric
* fix: address review comment
* fix: docs
1 parent 859c17f commit 88ee975

File tree

7 files changed

+372
-44
lines changed

7 files changed

+372
-44
lines changed

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,11 @@ def SyrkOp: EnzymeXLA_Op<"blas.syrk", [Pure, SameOperandsAndResultElementType]>
427427
let summary = "Multiplication involving a symmetric matrix";
428428

429429
let description = [{
430-
C := alpha*A*A^T + beta*C, or C := alpha*A^T*A + beta*C, where alpha and beta are scalars. C must be a n x n symmetric matrix."
430+
C := alpha*A*A^T + beta*C, or C := alpha*A^T*A + beta*C, where alpha and beta are
431+
scalars. C must be a n x n symmetric matrix.
432+
433+
If `fill` is present, then both the upper and lower triangles of the matrix are filled.
434+
Otherwise the values in the non-uplo part of the matrix are undefined.
431435
}];
432436

433437
let arguments = (ins
@@ -436,7 +440,8 @@ def SyrkOp: EnzymeXLA_Op<"blas.syrk", [Pure, SameOperandsAndResultElementType]>
436440
TensorFloat:$alpha,
437441
TensorFloat:$beta,
438442
EnzymeXLA_LapackUploAttr:$uplo,
439-
DefaultValuedAttr<EnzymeXLA_LapackTransposeAttr, "::mlir::enzymexla::LapackTranspose::none">:$transpose
443+
DefaultValuedAttr<EnzymeXLA_LapackTransposeAttr, "::mlir::enzymexla::LapackTranspose::none">:$transpose,
444+
OptionalAttr<UnitAttr>:$fill
440445
);
441446

442447
let results = (outs

src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp

Lines changed: 2 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3166,44 +3166,6 @@ struct IfOpEnzymeOpsRemover
31663166
}
31673167
};
31683168

3169-
Value getScalarInitValue(Operation *op, OpBuilder &builder) {
3170-
if (!op)
3171-
return nullptr;
3172-
3173-
// Splatted Constant
3174-
SplatElementsAttr elems;
3175-
if (matchPattern(op, m_Constant(&elems))) {
3176-
auto scalarElemType = RankedTensorType::get(
3177-
{}, cast<TensorType>(op->getResult(0).getType()).getElementType());
3178-
auto constInit = ConstantOp::create(builder, op->getLoc(), scalarElemType,
3179-
elems.resizeSplat(scalarElemType));
3180-
return constInit;
3181-
}
3182-
3183-
// BroadcastInDim / Reshape
3184-
if (isa<stablehlo::BroadcastInDimOp, stablehlo::ReshapeOp>(op)) {
3185-
if (cast<RankedTensorType>(op->getOperand(0).getType()).getRank() == 0) {
3186-
return op->getOperand(0);
3187-
}
3188-
}
3189-
3190-
// Convert
3191-
if (auto convertOp = dyn_cast<stablehlo::ConvertOp>(op)) {
3192-
auto scalar =
3193-
getScalarInitValue(convertOp.getOperand().getDefiningOp(), builder);
3194-
if (scalar) {
3195-
auto convertOutElemType =
3196-
cast<RankedTensorType>(convertOp.getResult().getType())
3197-
.getElementType();
3198-
return stablehlo::ConvertOp::create(
3199-
builder, op->getLoc(), RankedTensorType::get({}, convertOutElemType),
3200-
scalar);
3201-
}
3202-
}
3203-
3204-
return nullptr;
3205-
}
3206-
32073169
struct SHLOReduceOpBatchInterface
32083170
: public BatchOpInterface::ExternalModel<SHLOReduceOpBatchInterface,
32093171
ReduceOp> {
@@ -3234,8 +3196,7 @@ struct SHLOReduceOpBatchInterface
32343196
newReduceInits.reserve(reduceOp.getInitValues().size());
32353197
for (auto opValue : reduceOp.getInitValues()) {
32363198
auto batchedInit = mapper.lookup(opValue);
3237-
auto scalarInit =
3238-
getScalarInitValue(batchedInit.getDefiningOp(), builder);
3199+
auto scalarInit = getScalarValue(batchedInit.getDefiningOp(), builder);
32393200
if (!scalarInit) {
32403201
// TODO: we need to support broadcasting inits, or do we?
32413202
src->emitError("Unsupported reduce init for batched reduce");
@@ -3316,8 +3277,7 @@ struct SHLOReduceWindowOpBatchInterface
33163277
newReduceWindowInits.reserve(reduceWindowOp.getInitValues().size());
33173278
for (auto opValue : reduceWindowOp.getInitValues()) {
33183279
auto batchedInit = mapper.lookup(opValue);
3319-
auto scalarInit =
3320-
getScalarInitValue(batchedInit.getDefiningOp(), builder);
3280+
auto scalarInit = getScalarValue(batchedInit.getDefiningOp(), builder);
33213281
if (!scalarInit) {
33223282
src->emitError(
33233283
"Unsupported reduce window init for batched reduce window");

src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25871,6 +25871,227 @@ struct BinaryNegatedOperandsSimplify
2587125871
}
2587225872
};
2587325873

25874+
// currently limited to non-batched dot_general
25875+
struct DotGeneralToSyrk
25876+
: public CheckedOpRewritePattern<stablehlo::DotGeneralOp,
25877+
DotGeneralToSyrk> {
25878+
using CheckedOpRewritePattern<stablehlo::DotGeneralOp,
25879+
DotGeneralToSyrk>::CheckedOpRewritePattern;
25880+
25881+
LogicalResult matchAndRewriteImpl(stablehlo::DotGeneralOp op,
25882+
PatternRewriter &rewriter) const {
25883+
auto dotDims = op.getDotDimensionNumbers();
25884+
auto lhs = op.getLhs();
25885+
auto rhs = op.getRhs();
25886+
25887+
if (dotDims.getLhsBatchingDimensions().size() != 0 ||
25888+
dotDims.getRhsBatchingDimensions().size() != 0) {
25889+
return failure();
25890+
}
25891+
25892+
if (dotDims.getLhsContractingDimensions().size() != 1 ||
25893+
dotDims.getRhsContractingDimensions().size() != 1) {
25894+
return failure();
25895+
}
25896+
25897+
// check that transpose dimensions are [1,0]
25898+
auto isTrueTranspose = [](stablehlo::TransposeOp tOp) -> bool {
25899+
auto perm = tOp.getPermutation();
25900+
return perm.size() == 2 && perm[0] == 1 && perm[1] == 0;
25901+
};
25902+
25903+
auto lhsContractingDim = dotDims.getLhsContractingDimensions()[0];
25904+
auto rhsContractingDim = dotDims.getRhsContractingDimensions()[0];
25905+
25906+
Value syrkInput;
25907+
enzymexla::LapackTranspose lapackTranspose;
25908+
25909+
if (lhs == rhs && lhsContractingDim == rhsContractingDim) {
25910+
syrkInput = lhs;
25911+
if (lhsContractingDim == 1) {
25912+
lapackTranspose = enzymexla::LapackTranspose::none;
25913+
} else {
25914+
lapackTranspose = enzymexla::LapackTranspose::transpose;
25915+
}
25916+
}
25917+
25918+
if (auto lhsT = lhs.getDefiningOp<stablehlo::TransposeOp>()) {
25919+
if (isTrueTranspose(lhsT) && lhsT.getOperand() == rhs &&
25920+
lhsContractingDim == 1 - rhsContractingDim) {
25921+
syrkInput = rhs;
25922+
if (rhsContractingDim == 1) {
25923+
lapackTranspose = enzymexla::LapackTranspose::none;
25924+
} else {
25925+
lapackTranspose = enzymexla::LapackTranspose::transpose;
25926+
}
25927+
}
25928+
}
25929+
25930+
if (auto rhsT = rhs.getDefiningOp<stablehlo::TransposeOp>()) {
25931+
if (isTrueTranspose(rhsT) && rhsT.getOperand() == lhs &&
25932+
rhsContractingDim == 1 - lhsContractingDim) {
25933+
syrkInput = lhs;
25934+
if (lhsContractingDim == 0) {
25935+
lapackTranspose = enzymexla::LapackTranspose::transpose;
25936+
} else {
25937+
lapackTranspose = enzymexla::LapackTranspose::none;
25938+
}
25939+
}
25940+
}
25941+
25942+
if (!syrkInput)
25943+
return failure();
25944+
25945+
auto elemType =
25946+
cast<RankedTensorType>(syrkInput.getType()).getElementType();
25947+
auto alphaType = RankedTensorType::get({}, elemType);
25948+
25949+
auto syrkOp = enzymexla::SyrkOp::create(
25950+
rewriter, op.getLoc(), op.getResult().getType(), syrkInput,
25951+
stablehlo::ConstantOp::create(
25952+
rewriter, op.getLoc(), op.getType(),
25953+
cast<ElementsAttr>(makeAttr(op.getType(), 0))),
25954+
stablehlo::ConstantOp::create(
25955+
rewriter, op.getLoc(), alphaType,
25956+
cast<ElementsAttr>(makeAttr(alphaType, 1))),
25957+
stablehlo::ConstantOp::create(
25958+
rewriter, op.getLoc(), alphaType,
25959+
cast<ElementsAttr>(makeAttr(alphaType, 0))),
25960+
enzymexla::LapackUploAttr::get(op.getContext(),
25961+
enzymexla::LapackUplo::U),
25962+
enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
25963+
rewriter.getUnitAttr());
25964+
rewriter.replaceOp(op, syrkOp.getResult());
25965+
return success();
25966+
}
25967+
};
25968+
25969+
struct TransposeSyrkToSyrk
25970+
: public CheckedOpRewritePattern<enzymexla::SyrkOp, TransposeSyrkToSyrk> {
25971+
using CheckedOpRewritePattern<enzymexla::SyrkOp,
25972+
TransposeSyrkToSyrk>::CheckedOpRewritePattern;
25973+
25974+
LogicalResult matchAndRewriteImpl(enzymexla::SyrkOp op,
25975+
PatternRewriter &rewriter) const {
25976+
auto input = op.getA();
25977+
if (cast<RankedTensorType>(input.getType()).getRank() != 2)
25978+
return failure(); // support only rank 2 matrices for now
25979+
25980+
auto transposeOp = input.getDefiningOp<stablehlo::TransposeOp>();
25981+
if (!transposeOp)
25982+
return failure();
25983+
25984+
auto perm = transposeOp.getPermutation();
25985+
if (perm.size() != 2 || perm[0] != 1 || perm[1] != 0)
25986+
return failure();
25987+
25988+
enzymexla::LapackTranspose lapackTranspose;
25989+
switch (op.getTranspose()) {
25990+
case enzymexla::LapackTranspose::none:
25991+
lapackTranspose = enzymexla::LapackTranspose::transpose;
25992+
break;
25993+
default:
25994+
lapackTranspose = enzymexla::LapackTranspose::none;
25995+
}
25996+
25997+
rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
25998+
op, op.getResult().getType(), transposeOp.getOperand(), op.getC(),
25999+
op.getAlpha(), op.getBeta(), op.getUploAttr(),
26000+
enzymexla::LapackTransposeAttr::get(op.getContext(), lapackTranspose),
26001+
op.getFillAttr());
26002+
return success();
26003+
}
26004+
};
26005+
26006+
struct FuseMulIntoSyrk
26007+
: public CheckedOpRewritePattern<stablehlo::MulOp, FuseMulIntoSyrk> {
26008+
using CheckedOpRewritePattern<stablehlo::MulOp,
26009+
FuseMulIntoSyrk>::CheckedOpRewritePattern;
26010+
26011+
LogicalResult matchAndRewriteImpl(stablehlo::MulOp op,
26012+
PatternRewriter &rewriter) const {
26013+
auto lhs = op.getLhs();
26014+
auto rhs = op.getRhs();
26015+
26016+
enzymexla::SyrkOp syrkOp;
26017+
Value other;
26018+
26019+
if (auto lhsSyrk = lhs.getDefiningOp<enzymexla::SyrkOp>()) {
26020+
syrkOp = lhsSyrk;
26021+
other = rhs;
26022+
} else if (auto rhsSyrk = rhs.getDefiningOp<enzymexla::SyrkOp>()) {
26023+
syrkOp = rhsSyrk;
26024+
other = lhs;
26025+
} else {
26026+
return failure();
26027+
}
26028+
26029+
Value scalarVal =
26030+
stablehlo::getScalarValue(other.getDefiningOp(), rewriter);
26031+
if (!scalarVal)
26032+
return failure();
26033+
26034+
auto newBeta = stablehlo::MulOp::create(rewriter, op.getLoc(),
26035+
syrkOp.getBeta(), scalarVal);
26036+
auto newAlpha = stablehlo::MulOp::create(rewriter, op.getLoc(),
26037+
syrkOp.getAlpha(), scalarVal);
26038+
26039+
rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
26040+
op, syrkOp.getType(), syrkOp.getA(), syrkOp.getC(), newAlpha, newBeta,
26041+
syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
26042+
return success();
26043+
}
26044+
};
26045+
26046+
struct FuseAddIntoSyrk
26047+
: public CheckedOpRewritePattern<stablehlo::AddOp,
26048+
FuseAddIntoSyrk>::CheckedOpRewritePattern {
26049+
using CheckedOpRewritePattern<stablehlo::AddOp,
26050+
FuseAddIntoSyrk>::CheckedOpRewritePattern;
26051+
26052+
LogicalResult matchAndRewriteImpl(stablehlo::AddOp op,
26053+
PatternRewriter &rewriter) const {
26054+
auto lhs = op.getLhs();
26055+
auto rhs = op.getRhs();
26056+
26057+
enzymexla::SyrkOp syrkOp;
26058+
Value other;
26059+
26060+
if (auto lhsSyrk = lhs.getDefiningOp<enzymexla::SyrkOp>()) {
26061+
syrkOp = lhsSyrk;
26062+
other = rhs;
26063+
} else if (auto rhsSyrk = rhs.getDefiningOp<enzymexla::SyrkOp>()) {
26064+
syrkOp = rhsSyrk;
26065+
other = lhs;
26066+
} else {
26067+
return failure();
26068+
}
26069+
26070+
// we can fuse this addition iff the other operand is a symmetric matrix
26071+
if (!canApplySymmetricPattern(other, rewriter))
26072+
return failure();
26073+
26074+
auto oldBeta = syrkOp.getBeta();
26075+
auto bcastedBeta = stablehlo::BroadcastInDimOp::create(
26076+
rewriter, op.getLoc(), syrkOp.getType(), oldBeta,
26077+
rewriter.getDenseI64ArrayAttr({}));
26078+
26079+
auto scaledC = stablehlo::MulOp::create(rewriter, op.getLoc(),
26080+
syrkOp.getC(), bcastedBeta);
26081+
26082+
auto newC = stablehlo::AddOp::create(rewriter, op.getLoc(), scaledC, other);
26083+
26084+
auto newBeta = stablehlo::ConstantOp::create(
26085+
rewriter, op.getLoc(), oldBeta.getType(),
26086+
cast<ElementsAttr>(makeAttr(oldBeta.getType(), 1)));
26087+
26088+
rewriter.replaceOpWithNewOp<enzymexla::SyrkOp>(
26089+
op, syrkOp.getType(), syrkOp.getA(), newC, syrkOp.getAlpha(), newBeta,
26090+
syrkOp.getUploAttr(), syrkOp.getTransposeAttr(), syrkOp.getFillAttr());
26091+
return success();
26092+
}
26093+
};
26094+
2587426095
/////////////// End Imported from stablehlo
2587526096

2587626097
// clang-format off
@@ -26423,6 +26644,12 @@ struct EnzymeHLOOptPass
2642326644
patterns.add<TransposeSymmetricSimplify>(context);
2642426645
patterns.add<FactorScalarsInDotGeneral>(context);
2642526646

26647+
// syrk patterns
26648+
// currently disabled since lowering is missing
26649+
// patterns.add<DotGeneralToSyrk>(context);
26650+
patterns.add<TransposeSyrkToSyrk, FuseMulIntoSyrk, FuseAddIntoSyrk>(
26651+
context);
26652+
2642626653
// clang-format off
2642726654
patterns.add<
2642826655
WhileRepeatedInductionReduction,

src/enzyme_ad/jax/TransformOps/TransformOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2519,3 +2519,23 @@ def ApplyMultiplyNegatedOperandsSimplifyPatterns : EnzymeHLOPatternOp<
25192519
"multiply_negated_operands_simplify"> {
25202520
let patterns = ["BinaryNegatedOperandsSimplify<stablehlo::MulOp>"];
25212521
}
2522+
2523+
def ApplyDotGeneralToSyrkPatterns : EnzymeHLOPatternOp<
2524+
"dot_general_to_syrk"> {
2525+
let patterns = ["DotGeneralToSyrk"];
2526+
}
2527+
2528+
def ApplyTransposeSyrkToSyrkPatterns : EnzymeHLOPatternOp<
2529+
"transpose_syrk_to_syrk"> {
2530+
let patterns = ["TransposeSyrkToSyrk"];
2531+
}
2532+
2533+
def ApplyFuseMulIntoSyrkPatterns : EnzymeHLOPatternOp<
2534+
"fuse_mul_into_syrk"> {
2535+
let patterns = ["FuseMulIntoSyrk"];
2536+
}
2537+
2538+
def ApplyFuseAddIntoSyrkPatterns : EnzymeHLOPatternOp<
2539+
"fuse_add_into_syrk"> {
2540+
let patterns = ["FuseAddIntoSyrk"];
2541+
}

0 commit comments

Comments
 (0)