Skip to content

Commit 6f8a6b0

Browse files
committed
feat: broadcast check
1 parent 47917cc commit 6f8a6b0

File tree

3 files changed

+18
-16
lines changed

3 files changed

+18
-16
lines changed

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "mlir/Support/LLVM.h"
1414
#include "llvm/Support/raw_ostream.h"
1515

16+
#include "stablehlo/dialect/StablehloOps.h"
17+
1618
using namespace mlir;
1719
using namespace mlir::dataflow;
1820

@@ -75,9 +77,6 @@ void StructuredSparsityPattern::initializeBandwidths() {
7577
}
7678

7779
void StructuredSparsityPattern::refineKind() {
78-
if (kind != StructuredSparsityKind::Band)
79-
return;
80-
8180
if (lowerBandwidth == 0) {
8281
if (upperBandwidth == 0) {
8382
kind = StructuredSparsityKind::Diagonal;
@@ -215,15 +214,26 @@ ValueProperties::ValueProperties(Value v) {
215214
if (matchPattern(v, m_Constant(&denseAttr))) {
216215
auto props = getPropertiesFromDenseAttr(denseAttr);
217216
setFlags(props.getFlags());
218-
llvm::errs() << "v: " << v << " properties: ";
219-
this->print(llvm::errs());
220-
llvm::errs() << "\n";
221217
return;
222218
}
223219

224-
// TODO: symmetric checks Utils.cpp:688
220+
auto defOp = v.getDefiningOp();
221+
if (!defOp)
222+
return;
223+
224+
// comm_op(A, A^T) will always be symmetric
225225

226-
// TODO: broadcasted scalar
226+
// A x A^T will always be symmetric
227+
228+
if (auto bcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(defOp)) {
229+
auto operand = bcastOp.getOperand();
230+
if (cast<RankedTensorType>(operand.getType()).getRank() == 0) { // bcast(scalar)
231+
if (matchPattern(operand, m_One())) // bcast(1)
232+
set(ValueProperty::UnitDiagonal);
233+
set(ValueProperty::BroadcastedScalar);
234+
set(ValueProperty::Symmetric);
235+
}
236+
}
227237

228238
// TODO: unit diagonal
229239
// - iota scatter with constant
@@ -444,9 +454,5 @@ LogicalResult StructuredMatrixAnalysis::visitOperation(
444454
return success();
445455
}
446456

447-
//===----------------------------------------------------------------------===//
448-
// Structure Originators
449-
//===----------------------------------------------------------------------===//
450-
451457
} // namespace structure_analysis
452458
} // namespace mlir

src/enzyme_ad/jax/Analysis/StructuredMatrixAnalysis.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,5 @@ class StructuredMatrixAnalysis
254254
ArrayRef<StructuredMatrixLattice *> results) override;
255255
};
256256

257-
//===----------------------------------------------------------------------===//
258-
// Structure Originators
259-
//===----------------------------------------------------------------------===//
260-
261257
} // namespace structure_analysis
262258
} // namespace mlir

0 commit comments

Comments
 (0)