Skip to content

Commit d49cf52

Browse files
committed
Use dyn_cast for safety in casting RankedTensorType
1 parent e3c976f commit d49cf52

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,15 +639,15 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
639639
PatternRewriter &rewriter) {
640640
assert(op);
641641

642-
auto outTy = cast<RankedTensorType>(op->getResult(0).getType());
643-
if (outTy.getRank() != 2)
644642
if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
645643
if (boolAttr.getValue())
646644
return State::GUARANTEED;
647645
else
648646
return State::NOTGUARANTEED;
649647
}
650648

649+
auto outTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
650+
if (!outTy || outTy.getRank() != 2)
651651
return State::NOTGUARANTEED; // this pass only checks for symmetric matrices
652652
if (outTy.getDimSize(0) != outTy.getDimSize(1))
653653
return State::NOTGUARANTEED; // quick check and exit

0 commit comments

Comments
 (0)