Skip to content

Commit bc1dae9

Browse files
[MLIR][TORCH] Undo indices conversion to i32 for TMTensor ops (#4292)
Signed-off-by: Vivek Khandelwal <[email protected]>
Co-authored-by: Kunwar Grover
1 parent d7e3484 commit bc1dae9

File tree

2 files changed

+40
-20
lines changed

2 files changed

+40
-20
lines changed

lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
8989
// Store location for insertions
9090
Location loc = src.getLoc();
9191

92+
Type indicesElemType = getElementTypeOrSelf(indices);
9293
Value indexSize = getTensorSize(rewriter, loc, indices);
9394
indexSize = castIntToIndex(rewriter, loc, indexSize);
9495
SmallVector<Value> indexShape = getTensorSizes(rewriter, loc, indices);
@@ -97,7 +98,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
9798
// We flatten the `src` values from (i, j, k, ...) -> (i * j * k * ...)
9899
SmallVector<Value> indSliceShape({indexSize, cstOne});
99100
Value indSlice =
100-
createZeroInitTensor(rewriter, loc, indSliceShape, rewriter.getI32Type());
101+
createZeroInitTensor(rewriter, loc, indSliceShape, indicesElemType);
101102

102103
// New output shape will be equal to the product of the dimensions of the
103104
// updates
@@ -142,13 +143,13 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
142143
SmallVector<Value> yieldVals;
143144
for (Value v : indexValues) {
144145
Value scalar = castIndexToInt64(b, loc, v);
145-
yieldVals.push_back(b.create<arith::TruncIOp>(
146-
loc, rewriter.getI32Type(), scalar));
146+
yieldVals.push_back(convertScalarToDtype(
147+
rewriter, loc, scalar, indicesElemType));
147148
}
148149
// Replace the original index with the index specified
149150
// by the scatter.
150151
yieldVals[dim] = convertScalarToDtype(
151-
rewriter, loc, extractIndexValue, rewriter.getI32Type());
152+
rewriter, loc, extractIndexValue, indicesElemType);
152153
yieldVals.push_back(extractSrcValue);
153154
b.create<linalg::YieldOp>(loc, yieldVals);
154155
})
@@ -177,7 +178,7 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter,
177178
rewriter.create<arith::ConstantIndexOp>(loc, indexType.getRank());
178179
Value flattenedIndices = createZeroInitTensor(
179180
rewriter, loc, SmallVector<Value>({indexSize, indicesRank}),
180-
rewriter.getI32Type());
181+
indexType.getElementType());
181182
SmallVector<Value> scatterInputsVector(flattenedUpdates);
182183
for (auto const slice : ArrayRef(scatterInputsVector).drop_back()) {
183184
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, slice);
@@ -540,8 +541,7 @@ class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
540541

541542
// Creating a tm_tensor.scatter op with the following mapping:
542543
// 1.) `input` tensor maps to the indices in scatter op. `input` is
543-
// expanded from 1-d to 2-d, and its element type is set to i32 as required
544-
// for the scatter op.
544+
// expanded from 1-d to 2-d.
545545
// 2.) `updates` is a 1-d dummy tensor with the size equivalent to the
546546
// `input`.
547547
// 3.) `bincount` a 1-d tensor maps to the original in scatter op
@@ -556,12 +556,10 @@ class ConvertAtenBincountOp : public OpConversionPattern<AtenBincountOp> {
556556
Value expandedInputTensor = rewriter.create<AtenUnsqueezeOp>(
557557
loc, expandInputType, torchTypeInput, torchCstOne);
558558

559-
// Converting the input element type to i32.
560-
Value indices = convertTensorToDtype(
561-
rewriter, loc, expandedInputTensor,
562-
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
563-
indices = typeConverter->materializeTargetConversion(
564-
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
559+
Value indices = typeConverter->materializeTargetConversion(
560+
rewriter, loc,
561+
typeConverter->convertType(expandedInputTensor.getType()),
562+
expandedInputTensor);
565563

566564
auto resultType = cast<RankedTensorType>(
567565
typeConverter->convertType(op->getResult(0).getType()));
@@ -1039,7 +1037,6 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp
10391037
return failure();
10401038

10411039
Location loc = op.getLoc();
1042-
MLIRContext *context = op->getContext();
10431040
Value gradOutput = adaptor.getGradOutput();
10441041
Value input = adaptor.getSelf();
10451042
RankedTensorType gradOutputType =
@@ -1049,12 +1046,7 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp
10491046
Type inputElemType = inputType.getElementType();
10501047
int64_t tensorOperandRank = inputType.getRank();
10511048

1052-
// `TMTensor::ScatterOp` expects indices of element type i32.
1053-
Value indices = convertTensorToDtype(
1054-
rewriter, loc, op.getIndices(),
1055-
mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed));
1056-
indices = typeConverter->materializeTargetConversion(
1057-
rewriter, loc, typeConverter->convertType(indices.getType()), indices);
1049+
Value indices = adaptor.getIndices();
10581050
RankedTensorType indicesType = cast<RankedTensorType>(indices.getType());
10591051
Type indicesElemType = indicesType.getElementType();
10601052

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: torch-mlir-opt <%s -convert-torch-to-tmtensor -split-input-file -verify-diagnostics | FileCheck %s
2+
3+
// -----
4+
5+
// CHECK-LABEL: @scatter_src_i64_index
6+
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0, 1, 2>} unique_indices(false) ins(%{{.*}}, %{{.*}} : tensor<?xf32>, tensor<?x3xi64>) outs(%{{.*}} : tensor<10x8x6xf32>) {
7+
// CHECK: ^bb0(%arg3: f32, %arg4: f32):
8+
// CHECK: tm_tensor.yield %arg3 : f32
9+
// CHECK: } -> tensor<10x8x6xf32>
10+
func.func @scatter_src_i64_index(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si64>, %arg2: !torch.vtensor<[5,8,6],f32>) -> !torch.vtensor<[10,8,6],f32> {
11+
%int0 = torch.constant.int 0
12+
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[10,8,6],f32>, !torch.int, !torch.vtensor<[2,4,3],si64>, !torch.vtensor<[5,8,6],f32> -> !torch.vtensor<[10,8,6],f32>
13+
return %0 : !torch.vtensor<[10,8,6],f32>
14+
}
15+
16+
17+
// -----
18+
19+
// CHECK-LABEL: @scatter_src_i32_index
20+
// CHECK: tm_tensor.scatter {dimension_map = array<i64: 0, 1, 2>} unique_indices(false) ins(%{{.*}}, %{{.*}} : tensor<?xf32>, tensor<?x3xi32>) outs(%{{.*}} : tensor<10x8x6xf32>) {
21+
// CHECK: ^bb0(%arg3: f32, %arg4: f32):
22+
// CHECK: tm_tensor.yield %arg3 : f32
23+
// CHECK: } -> tensor<10x8x6xf32>
24+
func.func @scatter_src_i32_index(%arg0: !torch.vtensor<[10,8,6],f32>, %arg1: !torch.vtensor<[2,4,3],si32>, %arg2: !torch.vtensor<[5,8,6],f32>) -> !torch.vtensor<[10,8,6],f32> {
25+
%int0 = torch.constant.int 0
26+
%0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[10,8,6],f32>, !torch.int, !torch.vtensor<[2,4,3],si32>, !torch.vtensor<[5,8,6],f32> -> !torch.vtensor<[10,8,6],f32>
27+
return %0 : !torch.vtensor<[10,8,6],f32>
28+
}

0 commit comments

Comments
 (0)