@@ -89,6 +89,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
8989 // Store location for insertions
9090 Location loc = src.getLoc ();
9191
92+ Type indicesElemType = getElementTypeOrSelf (indices);
9293 Value indexSize = getTensorSize (rewriter, loc, indices);
9394 indexSize = castIntToIndex (rewriter, loc, indexSize);
9495 SmallVector<Value> indexShape = getTensorSizes (rewriter, loc, indices);
@@ -97,7 +98,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
9798 // We flatten the `src` values from (i, j, k, ...) -> (i * j * k * ...)
9899 SmallVector<Value> indSliceShape ({indexSize, cstOne});
99100 Value indSlice =
100- createZeroInitTensor (rewriter, loc, indSliceShape, rewriter. getI32Type () );
101+ createZeroInitTensor (rewriter, loc, indSliceShape, indicesElemType );
101102
102103 // New output shape will be equal to the product of the dimensions of the
103104 // updates
@@ -142,13 +143,13 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
142143 SmallVector<Value> yieldVals;
143144 for (Value v : indexValues) {
144145 Value scalar = castIndexToInt64 (b, loc, v);
145- yieldVals.push_back (b. create <arith::TruncIOp> (
146- loc, rewriter. getI32Type () , scalar));
146+ yieldVals.push_back (convertScalarToDtype (
147+ rewriter, loc , scalar, indicesElemType ));
147148 }
148149 // Replace the original index with the index specified
149150 // by the scatter.
150151 yieldVals[dim] = convertScalarToDtype (
151- rewriter, loc, extractIndexValue, rewriter. getI32Type () );
152+ rewriter, loc, extractIndexValue, indicesElemType );
152153 yieldVals.push_back (extractSrcValue);
153154 b.create <linalg::YieldOp>(loc, yieldVals);
154155 })
@@ -177,7 +178,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
177178 rewriter.create <arith::ConstantIndexOp>(loc, indexType.getRank ());
178179 Value flattenedIndices = createZeroInitTensor (
179180 rewriter, loc, SmallVector<Value>({indexSize, indicesRank}),
180- rewriter. getI32Type ());
181+ indexType. getElementType ());
181182 SmallVector<Value> scatterInputsVector (flattenedUpdates);
182183 for (auto const slice : ArrayRef (scatterInputsVector).drop_back ()) {
183184 SmallVector<Value> sizes = getTensorSizes (rewriter, loc, slice);
@@ -540,8 +541,7 @@ class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
540541
541542 // Creating a tm_tensor.scatter op with the following mapping:
542543 // 1.) `input` tensor maps to the indices in scatter op. `input` is
543- // expanded from 1-d to 2-d, and its element type is set to i32 as required
544- // for the scatter op.
544+ // expanded from 1-d to 2-d.
545545 // 2.) `updates` is a 1-d dummy tensor with the size equivalent to the
546546 // `input`.
547547 // 3.) `bincount` a 1-d tensor maps to the original in scatter op
@@ -556,12 +556,10 @@ class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
556556 Value expandedInputTensor = rewriter.create <AtenUnsqueezeOp>(
557557 loc, expandInputType, torchTypeInput, torchCstOne);
558558
559- // Converting the input element type to i32.
560- Value indices = convertTensorToDtype (
561- rewriter, loc, expandedInputTensor,
562- mlir::IntegerType::get (context, 32 , mlir::IntegerType::Signed));
563- indices = typeConverter->materializeTargetConversion (
564- rewriter, loc, typeConverter->convertType (indices.getType ()), indices);
559+ Value indices = typeConverter->materializeTargetConversion (
560+ rewriter, loc,
561+ typeConverter->convertType (expandedInputTensor.getType ()),
562+ expandedInputTensor);
565563
566564 auto resultType = cast<RankedTensorType>(
567565 typeConverter->convertType (op->getResult (0 ).getType ()));
@@ -1039,7 +1037,6 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp
10391037 return failure ();
10401038
10411039 Location loc = op.getLoc ();
1042- MLIRContext *context = op->getContext ();
10431040 Value gradOutput = adaptor.getGradOutput ();
10441041 Value input = adaptor.getSelf ();
10451042 RankedTensorType gradOutputType =
@@ -1049,12 +1046,7 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp
10491046 Type inputElemType = inputType.getElementType ();
10501047 int64_t tensorOperandRank = inputType.getRank ();
10511048
1052- // `TMTensor::ScatterOp` expects indices of element type i32.
1053- Value indices = convertTensorToDtype (
1054- rewriter, loc, op.getIndices (),
1055- mlir::IntegerType::get (context, 32 , mlir::IntegerType::Signed));
1056- indices = typeConverter->materializeTargetConversion (
1057- rewriter, loc, typeConverter->convertType (indices.getType ()), indices);
1049+ Value indices = adaptor.getIndices ();
10581050 RankedTensorType indicesType = cast<RankedTensorType>(indices.getType ());
10591051 Type indicesElemType = indicesType.getElementType ();
10601052
0 commit comments