|
13 | 13 | #include "mlir/Support/LLVM.h" |
14 | 14 | #include "llvm/Support/raw_ostream.h" |
15 | 15 |
|
| 16 | +#include "stablehlo/dialect/StablehloOps.h" |
| 17 | + |
16 | 18 | using namespace mlir; |
17 | 19 | using namespace mlir::dataflow; |
18 | 20 |
|
@@ -75,9 +77,6 @@ void StructuredSparsityPattern::initializeBandwidths() { |
75 | 77 | } |
76 | 78 |
|
77 | 79 | void StructuredSparsityPattern::refineKind() { |
78 | | - if (kind != StructuredSparsityKind::Band) |
79 | | - return; |
80 | | - |
81 | 80 | if (lowerBandwidth == 0) { |
82 | 81 | if (upperBandwidth == 0) { |
83 | 82 | kind = StructuredSparsityKind::Diagonal; |
@@ -215,15 +214,26 @@ ValueProperties::ValueProperties(Value v) { |
215 | 214 | if (matchPattern(v, m_Constant(&denseAttr))) { |
216 | 215 | auto props = getPropertiesFromDenseAttr(denseAttr); |
217 | 216 | setFlags(props.getFlags()); |
218 | | - llvm::errs() << "v: " << v << " properties: "; |
219 | | - this->print(llvm::errs()); |
220 | | - llvm::errs() << "\n"; |
221 | 217 | return; |
222 | 218 | } |
223 | 219 |
|
224 | | - // TODO: symmetric checks Utils.cpp:688 |
| 220 | + auto defOp = v.getDefiningOp(); |
| 221 | + if (!defOp) |
| 222 | + return; |
| 223 | + |
| 224 | + // comm_op(A, A^T) will always be symmetric |
225 | 225 |
|
226 | | - // TODO: broadcasted scalar |
| 226 | + // A x A^T will always be symmetric |
| 227 | + |
| 228 | + if (auto bcastOp = dyn_cast<stablehlo::BroadcastInDimOp>(defOp)) { |
| 229 | + auto operand = bcastOp.getOperand(); |
| 230 | + if (cast<RankedTensorType>(operand.getType()).getRank() == 0) { // bcast(scalar) |
| 231 | + if (matchPattern(operand, m_One())) // bcast(1) |
| 232 | + set(ValueProperty::UnitDiagonal); |
| 233 | + set(ValueProperty::BroadcastedScalar); |
| 234 | + set(ValueProperty::Symmetric); |
| 235 | + } |
| 236 | + } |
227 | 237 |
|
228 | 238 | // TODO: unit diagonal |
229 | 239 | // - iota scatter with constant |
@@ -444,9 +454,5 @@ LogicalResult StructuredMatrixAnalysis::visitOperation( |
444 | 454 | return success(); |
445 | 455 | } |
446 | 456 |
|
447 | | -//===----------------------------------------------------------------------===// |
448 | | -// Structure Originators |
449 | | -//===----------------------------------------------------------------------===// |
450 | | - |
451 | 457 | } // namespace structure_analysis |
452 | 458 | } // namespace mlir |
0 commit comments