Skip to content

Commit e3cd107

Browse files
committed
Move boolAttr checks in each analysis to shared check in localGuaranteedWithSetAttr
1 parent b4a4313 commit e3cd107

File tree

2 files changed

+11
-22
lines changed

2 files changed

+11
-22
lines changed

src/enzyme_ad/jax/Utils.cpp

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -748,13 +748,6 @@ NoNanResultAnalysis::localGuaranteed(Operation *op,
748748
PatternRewriter &rewriter) {
749749
assert(op);
750750

751-
if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
752-
if (boolAttr.getValue())
753-
return State::GUARANTEED;
754-
else
755-
return State::NOTGUARANTEED;
756-
}
757-
758751
DenseElementsAttr denseAttr;
759752
if (matchPattern(op, m_Constant(&denseAttr))) {
760753
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
@@ -891,13 +884,6 @@ FiniteResultAnalysis::localGuaranteed(Operation *op,
891884
PatternRewriter &rewriter) {
892885
assert(op);
893886

894-
if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
895-
if (boolAttr.getValue())
896-
return State::GUARANTEED;
897-
else
898-
return State::NOTGUARANTEED;
899-
}
900-
901887
DenseElementsAttr denseAttr;
902888
if (matchPattern(op, m_Constant(&denseAttr))) {
903889
if (guaranteedConstantOp(op, denseAttr, rewriter)) {
@@ -1004,13 +990,6 @@ NonNegativeResultAnalysis::State NonNegativeResultAnalysis::localGuaranteed(
1004990
PatternRewriter &rewriter) {
1005991
assert(op);
1006992

1007-
if (auto boolAttr = op->getAttrOfType<BoolAttr>(getAttrName())) {
1008-
if (boolAttr.getValue())
1009-
return State::GUARANTEED;
1010-
else
1011-
return State::NOTGUARANTEED;
1012-
}
1013-
1014993
DenseElementsAttr denseAttr;
1015994
if (matchPattern(op, m_Constant(&denseAttr))) {
1016995
if (guaranteedConstantOp(op, denseAttr, rewriter)) {

src/enzyme_ad/jax/Utils.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -541,8 +541,18 @@ template <typename Child> class GuaranteedResultAnalysisBase {
541541
State localGuaranteedWithSetAttr(Operation *op,
542542
SmallVectorImpl<Operation *> &localtodo,
543543
PatternRewriter &rewriter) {
544-
auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter);
544+
545545
auto attrName = ((Child *)this)->getAttrName();
546+
547+
if (auto boolAttr = op->getAttrOfType<BoolAttr>(attrName)) {
548+
if (boolAttr.getValue())
549+
return State::GUARANTEED;
550+
else
551+
return State::NOTGUARANTEED;
552+
}
553+
554+
auto state = ((Child *)this)->localGuaranteed(op, localtodo, rewriter);
555+
546556
switch (state) {
547557
case State::GUARANTEED:
548558
rewriter.modifyOpInPlace(op, [&]() {

0 commit comments

Comments
 (0)