@@ -45,85 +45,68 @@ static SmallVector<int64_t> getReduceOutputShape(ArrayRef<int64_t> inputShape,
4545static Value createInitialValueForReduceOp (Operation *op, Type elementTy,
4646 PatternRewriter &rewriter) {
4747 auto constType = RankedTensorType::get ({}, elementTy);
48+ DenseElementsAttr constAttr = nullptr ;
4849 if (isa<AtenSumOp, AtenSumDimIntListOp, AtenFrobeniusNormDimOp,
4950 AtenLinalgVectorNormOp>(op)) {
5051 if (isa<mlir::FloatType>(elementTy)) {
51- auto constAttr = DenseElementsAttr::get (
52+ constAttr = DenseElementsAttr::get (
5253 constType, {APFloat::getZero (
5354 cast<mlir::FloatType>(elementTy).getFloatSemantics (),
5455 /* negative=*/ false )});
55- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
56- constAttr);
5756 } else if (isa<mlir::IntegerType>(elementTy)) {
58- auto constAttr = DenseElementsAttr::get (
57+ constAttr = DenseElementsAttr::get (
5958 constType, {APInt::getZero (elementTy.getIntOrFloatBitWidth ())});
60- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
61- constAttr);
6259 }
6360 }
6461
6562 if (isa<AtenAmaxOp, AtenMaxOp, AtenMaxDimOp, AtenArgmaxOp>(op)) {
6663 if (isa<mlir::FloatType>(elementTy)) {
67- auto constAttr = DenseElementsAttr::get (
64+ constAttr = DenseElementsAttr::get (
6865 constType,
6966 {APFloat::getInf (cast<mlir::FloatType>(elementTy).getFloatSemantics (),
7067 /* negative=*/ true )});
71- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
72- constAttr);
7368 } else if (isa<mlir::IntegerType>(elementTy)) {
74- auto constAttr = DenseElementsAttr::get (
69+ constAttr = DenseElementsAttr::get (
7570 constType,
7671 {APInt::getSignedMinValue (elementTy.getIntOrFloatBitWidth ())});
77- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
78- constAttr);
7972 }
8073 }
8174
8275 if (isa<AtenAminOp, AtenMinOp, AtenMinDimOp, AtenArgminOp>(op)) {
8376 if (isa<mlir::FloatType>(elementTy)) {
84- auto constAttr = DenseElementsAttr::get (
77+ constAttr = DenseElementsAttr::get (
8578 constType,
8679 {APFloat::getInf (cast<mlir::FloatType>(elementTy).getFloatSemantics (),
8780 /* negative=*/ false )});
88- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
89- constAttr);
9081 } else if (isa<mlir::IntegerType>(elementTy)) {
91- auto constAttr = DenseElementsAttr::get (
82+ constAttr = DenseElementsAttr::get (
9283 constType,
9384 {APInt::getSignedMaxValue (elementTy.getIntOrFloatBitWidth ())});
94- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
95- constAttr);
9685 }
9786 }
9887
9988 if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
10089 if (isa<mlir::FloatType>(elementTy)) {
10190 APFloat one (cast<mlir::FloatType>(elementTy).getFloatSemantics (), 1 );
102- auto constAttr = DenseElementsAttr::get (constType, one);
103- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
104- constAttr);
91+ constAttr = DenseElementsAttr::get (constType, one);
10592 } else if (isa<mlir::IntegerType>(elementTy)) {
10693 APInt one (elementTy.getIntOrFloatBitWidth (), 1 );
107- auto constAttr = DenseElementsAttr::get (constType, one);
108- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
109- constAttr);
94+ constAttr = DenseElementsAttr::get (constType, one);
11095 }
11196 }
11297
11398 if (isa<AtenAllOp, AtenAllDimOp>(op)) {
114- auto constAttr =
115- DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 1 )});
116- return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
117- constAttr);
99+ constAttr = DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 1 )});
118100 }
119101
120102 if (isa<AtenAnyOp, AtenAnyDimOp, AtenAnyDimsOp>(op)) {
121- auto constAttr =
122- DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 0 )});
103+ constAttr = DenseElementsAttr::get (constType, {APInt (/* numBits=*/ 1 , 0 )});
104+ }
105+
106+ if (constAttr != nullptr ) {
123107 return rewriter.create <stablehlo::ConstantOp>(op->getLoc (), constType,
124108 constAttr);
125109 }
126-
127110 op->emitError (" unimplemented lowering in "
128111 " createInitialValueForReduceOp" );
129112 return nullptr ;
@@ -483,7 +466,7 @@ class ConvertAtenReduceDimsOp : public ConvertAtenReductionOp<AtenOpT> {
483466 return rewriter.notifyMatchFailure (
484467 op, " non-const integer `dim` is not supported" );
485468 }
486- if (inputDims.size () == 0 ) {
469+ if (inputDims.empty () ) {
487470 dims = llvm::to_vector (llvm::seq<int64_t >(0 , inputTy.getRank ()));
488471 } else {
489472 for (auto d : inputDims) {
@@ -570,7 +553,7 @@ class ConvertAtenReduceWithIndicesOp : public ConvertAtenReductionOp<AtenOpT> {
570553 return rewriter.notifyMatchFailure (
571554 op, " failed to get dimension sizes of the input" );
572555 }
573- auto inputShapeVec = *inputShapeInfo;
556+ auto & inputShapeVec = *inputShapeInfo;
574557
575558 if (op.getResult (1 ).use_empty ()) {
576559 llvm::SmallVector<int64_t > outputShape (inputTy.getShape ());
@@ -643,7 +626,7 @@ LogicalResult ConvertAtenReductionOp<AtenAnyDimsOp>::matchAndRewrite(
643626 return rewriter.notifyMatchFailure (
644627 op, " non-const integer `dim` is not supported" );
645628 }
646- if (inputDims.size () == 0 ) {
629+ if (inputDims.empty () ) {
647630 rewriter.replaceOp (op, input);
648631 return success ();
649632 }
@@ -722,7 +705,7 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
722705 return rewriter.notifyMatchFailure (
723706 op, " non-const integer `dim` is not supported" );
724707 }
725- if (inputDims.size () == 0 ) {
708+ if (inputDims.empty () ) {
726709 inputDims = llvm::to_vector<4 >(llvm::seq<int64_t >(0 , inputTy.getRank ()));
727710 }
728711 }
0 commit comments