@@ -3537,6 +3537,30 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
35373537};
35383538} // namespace
35393539
3540+ namespace { // Start of rearrangement ops utility functions
3541+ // Extracts shape as vector of int64_t from vector of Value
3542+ SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
3543+ SmallVector<int64_t> shape;
3544+ shape.reserve(vals.size());
3545+ for (Value v : vals) {
3546+ int64_t cst_val;
3547+ if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3548+ shape.push_back(cst_val);
3549+ } else {
3550+ shape.push_back(kUnknownSize);
3551+ }
3552+ }
3553+ return shape;
3554+ }
3555+
3556+ // Converts a vector of Value (shape dimensions) into a ValueTensorType
3557+ ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
3558+ SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
3559+ return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
3560+ inOptionalDType);
3561+ }
3562+ } // namespace
3563+
35403564// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
35413565// prims.collapse operations.
35423566//
@@ -3562,7 +3586,6 @@ class DecomposeAtenPixelShuffleOp
35623586 using OpRewritePattern::OpRewritePattern;
35633587 LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
35643588 PatternRewriter &rewriter) const override {
3565-
35663589 Location loc = op.getLoc();
35673590 Value inValue = op.getSelf();
35683591 auto inType = cast<BaseTensorType>(inValue.getType());
@@ -3585,27 +3608,6 @@ class DecomposeAtenPixelShuffleOp
35853608
35863609 const auto inOptionalDType = inType.getOptionalDtype();
35873610
3588- auto getTypeFromShape = [inOptionalDType](auto &&vals) {
3589- // Get a vector of integers from a vector of Values.
3590- auto getIntShape = [](auto &&vals) {
3591- SmallVector<int64_t> shape;
3592- shape.reserve(vals.size());
3593- for (auto v : vals) {
3594- int64_t cst_val;
3595- if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3596- shape.push_back(cst_val);
3597- } else {
3598- shape.push_back(kUnknownSize);
3599- }
3600- }
3601- return shape;
3602- };
3603-
3604- const auto intShape = getIntShape(vals);
3605- return ValueTensorType::get(vals[0].getContext(),
3606- llvm::ArrayRef(intShape), inOptionalDType);
3607- };
3608-
36093611 auto nLeadingDims = inRank - 3;
36103612
36113613 // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
@@ -3677,24 +3679,24 @@ class DecomposeAtenPixelShuffleOp
36773679 auto partiallyExpanded =
36783680 rewriter
36793681 .create<PrimsSplitDimOp>(
3680- loc, getTypeFromShape(partiallyExpandedShape), inValue ,
3681- dimensionConstants[nLeadingDims], outC)
3682+ loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType) ,
3683+ inValue, dimensionConstants[nLeadingDims], outC)
36823684 .getResult();
36833685
36843686 // Split new dimension factorSquared -> (factor, factor)
36853687 auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3686- loc, getTypeFromShape(prePermuteShape), partiallyExpanded ,
3687- dimensionConstants[nLeadingDims + 1], factor);
3688+ loc, getTypeFromShape(prePermuteShape, inOptionalDType) ,
3689+ partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);
36883690
36893691 // Perform the permutation
3690- auto permuted =
3691- rewriter.create<AtenPermuteOp>( loc, getTypeFromShape(postPermuteShape) ,
3692- fullyExpanded, permuteDimsOrder);
3692+ auto permuted = rewriter.create<AtenPermuteOp>(
3693+ loc, getTypeFromShape(postPermuteShape, inOptionalDType), fullyExpanded ,
3694+ permuteDimsOrder);
36933695
36943696 // Collapse final 2 dimension
36953697 auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3696- loc, getTypeFromShape(partiallyCollapsedShape), permuted ,
3697- dimensionConstants[nLeadingDims + 3],
3698+ loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType) ,
3699+ permuted, dimensionConstants[nLeadingDims + 3],
36983700 dimensionConstants[nLeadingDims + 4]);
36993701
37003702 // Collapse back to original rank
@@ -3708,6 +3710,147 @@ class DecomposeAtenPixelShuffleOp
37083710};
37093711} // namespace
37103712
3713+ // Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
3714+ // prims.collapse operations.
3715+ //
3716+ // If input is a tensor of shape
3717+ // (N, g*C, H, W),
3718+ //
3719+ // then
3720+ // X = channel_shuffle(input, groups)
3721+ //
3722+ // gets replaced with
3723+ // X = input.split_dim(...) # shape (N, g, C, *)
3724+ // X = X.permute(0, 2, 1, ...) # shape (N, C, g, *)
3725+ // X = X.collapse(...) # shape (N, C*g, *)
3726+ //
3727+ // 'g' above is referred to as the number of 'groups'. N is the batch
3728+ // dimension, and can't be omitted. In PyTorch's ChannelShuffle operator
3729+ // if the batch dimension is omitted, the first spatial dimension is seen
3730+ // as the channel. PyTorch errors out for the code below indicating that
3731+ // 4 is not divisible by 3:
3732+ // input_tensor = torch.arange(1, 37, dtype=torch.float32).view(3, 4, 3)
3733+ // channel_shuffle_layer = nn.ChannelShuffle(groups=3)
3734+ // output_tensor = channel_shuffle_layer(input_tensor)
3735+ //
3736+ // The decomposition is based on this specification:
3737+ // https://pytorch.org/docs/stable/generated/torch.nn.ChannelShuffle.html
3738+ // and PyTorch implementation: aten/src/ATen/native/ChanelShuffle.cpp
3739+ // (yes, the filename is misspelled "Chanel" in upstream PyTorch)
3740+ //
3741+ namespace {
3742+ class DecomposeAtenChannelShuffleOp
3743+ : public OpRewritePattern<AtenChannelShuffleOp> {
3744+ public:
3745+ using OpRewritePattern::OpRewritePattern;
3746+ LogicalResult matchAndRewrite(AtenChannelShuffleOp op,
3747+ PatternRewriter &rewriter) const override {
3748+ Location loc = op.getLoc();
3749+ Value inValue = op.getSelf();
3750+ auto inType = cast<BaseTensorType>(inValue.getType());
3751+ auto maybeSizes = inType.getOptionalSizes();
3752+ if (!maybeSizes) {
3753+ return rewriter.notifyMatchFailure(
3754+ op, "Expected input tensor to have known rank.");
3755+ }
3756+ auto inShape = maybeSizes.value();
3757+ auto inRank = inShape.size();
3758+
3759+ // The input tensor must have at least 3 dimensions: batch size,
3760+ // channel size, and at least one spatial dimension.
3761+ if (inRank < 3)
3762+ return rewriter.notifyMatchFailure(
3763+ op, "Expected input tensor to have rank greater than or equal to 3.");
3764+
3765+ auto numOfSpatialDims = inRank - 2;
3766+
3767+ // Get the size of the dimension 'i'. Note the use of 'createOrFold'
3768+ // instead of 'create': if the dimension size is known, then the
3769+ // AtenSizeIntOp is folded to a ConstantOp.
3770+ auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
3771+ Value dim =
3772+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3773+ return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3774+ };
3775+
3776+ // The channel dimension is always the second dimension. PyTorch errors out
3777+ // if the batch dimension (first dimension) is not present. See comment at
3778+ // the top of this class for details.
3779+ auto inC = getDimSize(1);
3780+ SmallVector<Value> inSpatialDims;
3781+ inSpatialDims.reserve(numOfSpatialDims);
3782+ for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
3783+ inSpatialDims.push_back(getDimSize(i));
3784+ }
3785+
3786+ auto groups = op.getGroups();
3787+
3788+ // Temporary channel dimension size: tempC = inC / groups
3789+ // Assumes input has been validated: `inC % groups == 0`
3790+ // This is enforced by PyTorch's runtime and is required for correctness.
3791+ Value tempC = rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, groups);
3792+
3793+ // Create constants for split/permute/collapse operations. Note that we
3794+ // need an extra constant for the channel dimension split.
3795+ SmallVector<Value> dimensionConstants;
3796+ dimensionConstants.reserve(inRank + 1);
3797+ for (unsigned i = 0; i < inRank + 1; ++i) {
3798+ dimensionConstants.push_back(
3799+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3800+ }
3801+
3802+ Value batchDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3803+ loc, inValue, dimensionConstants[0]);
3804+
3805+ SmallVector<Value> splitShape;
3806+ splitShape.reserve(inRank + 1);
3807+ splitShape.append({batchDimSize, groups, tempC});
3808+ splitShape.append(inSpatialDims); // Appends all spatial dimensions
3809+
3810+ SmallVector<Value> permuteShape;
3811+ permuteShape.reserve(inRank + 1);
3812+ permuteShape.append({batchDimSize, tempC, groups});
3813+ permuteShape.append(inSpatialDims); // Appends all spatial dimensions
3814+
3815+ // Permute (N, groups, tempC, *) -> (N, tempC, groups, *)
3816+ SmallVector<Value> permutation{dimensionConstants[0], // batch dimension
3817+ dimensionConstants[2], // tempC
3818+ dimensionConstants[1]}; // groups
3819+ for (unsigned i = 3; i < inRank + 1; ++i) {
3820+ permutation.push_back(dimensionConstants[i]);
3821+ }
3822+
3823+ Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3824+ loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3825+ permutation);
3826+
3827+ const auto inOptionalDType = inType.getOptionalDtype();
3828+
3829+ Value dimC = dimensionConstants[1];
3830+ Value dimG = dimensionConstants[2];
3831+
3832+ // Split input channel inC -> (groups, inC/groups)
3833+ auto expandedTensor =
3834+ rewriter
3835+ .create<PrimsSplitDimOp>(
3836+ loc, getTypeFromShape(splitShape, inOptionalDType), inValue,
3837+ dimC, tempC)
3838+ .getResult();
3839+
3840+ // Perform the permutation
3841+ auto permuted = rewriter.create<AtenPermuteOp>(
3842+ loc, getTypeFromShape(permuteShape, inOptionalDType), expandedTensor,
3843+ permuteDimsOrder);
3844+
3845+ // Collapse (C, groups) back into a single channel dimension
3846+ rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
3847+ dimC, dimG);
3848+
3849+ return success();
3850+ }
3851+ };
3852+ } // namespace
3853+
37113854// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
37123855static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
37133856 Value input) {
@@ -12518,6 +12661,7 @@ class DecomposeComplexOpsPass
1251812661 addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
1251912662 addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
1252012663 addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
12664+ addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
1252112665 addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
1252212666 addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
1252312667 patterns);
0 commit comments