Skip to content

Commit 05720e3

Browse files
committed
Revert "Use dyn_cast for safety in casting RankedTensorType"
This reverts commit 662e48e.
1 parent a4954ce commit 05720e3

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -605,15 +605,15 @@ SymmetricResultAnalysis::State SymmetricResultAnalysis::localGuaranteed(
605605
PatternRewriter &rewriter) {
606606
assert(op);
607607

608+
auto outTy = cast<RankedTensorType>(op->getResult(0).getType());
609+
if (outTy.getRank() != 2)
608610
if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
609611
if (boolAttr.getValue())
610612
return State::GUARANTEED;
611613
else
612614
return State::NOTGUARANTEED;
613615
}
614616

615-
auto outTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
616-
if (!outTy || outTy.getRank() != 2)
617617
return State::NOTGUARANTEED; // this pass only checks for symmetric matrices
618618
if (outTy.getDimSize(0) != outTy.getDimSize(1))
619619
return State::NOTGUARANTEED; // quick check and exit

0 commit comments

Comments
 (0)