@@ -3536,30 +3536,6 @@ class DecomposeAten_LinalgDetOp : public OpRewritePattern<Aten_LinalgDetOp> {
35363536};
35373537} // namespace
35383538
3539- namespace { // Start of rearrangement ops utility functions
3540- // Extracts shape as vector of int64_t from vector of Value
3541- SmallVector<int64_t> getIntShapeFromValues(ArrayRef<Value> vals) {
3542- SmallVector<int64_t> shape;
3543- shape.reserve(vals.size());
3544- for (Value v : vals) {
3545- int64_t cst_val;
3546- if (matchPattern(v, m_TorchConstantInt(&cst_val))) {
3547- shape.push_back(cst_val);
3548- } else {
3549- shape.push_back(kUnknownSize);
3550- }
3551- }
3552- return shape;
3553- }
3554-
3555- // Converts a vector of Value (shape dimensions) into a ValueTensorType
3556- ValueTensorType getTypeFromShape(ArrayRef<Value> vals, Type inOptionalDType) {
3557- SmallVector<int64_t> intShape = getIntShapeFromValues(vals);
3558- return ValueTensorType::get(vals[0].getContext(), llvm::ArrayRef(intShape),
3559- inOptionalDType);
3560- }
3561- } // namespace
3562-
35633539// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
35643540// prims.collapse operations.
35653541//
@@ -3609,18 +3585,9 @@ class DecomposeAtenPixelShuffleOp
36093585
36103586 auto nLeadingDims = inRank - 3;
36113587
3612- // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
3613- // of 'create': if the dimension size is known, then the AtenSizeIntOp is
3614- // folded to a ConstantOp.
3615- auto getDimSize = [&](uint64_t i) -> Value {
3616- Value dim =
3617- rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3618- return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3619- };
3620-
3621- auto inC = getDimSize(inRank - 3);
3622- auto inH = getDimSize(inRank - 2);
3623- auto inW = getDimSize(inRank - 1);
3588+ auto inC = getTensorDimSize(rewriter, inValue, inRank - 3);
3589+ auto inH = getTensorDimSize(rewriter, inValue, inRank - 2);
3590+ auto inW = getTensorDimSize(rewriter, inValue, inRank - 1);
36243591
36253592 auto factor = op.getUpscaleFactor();
36263593
@@ -3678,23 +3645,26 @@ class DecomposeAtenPixelShuffleOp
36783645 auto partiallyExpanded =
36793646 rewriter
36803647 .create<PrimsSplitDimOp>(
3681- loc, getTypeFromShape(partiallyExpandedShape, inOptionalDType),
3648+ loc,
3649+ getTensorTypeFromShapeValues(partiallyExpandedShape,
3650+ inOptionalDType),
36823651 inValue, dimensionConstants[nLeadingDims], outC)
36833652 .getResult();
36843653
36853654 // Split new dimension factorSquared -> (factor, factor)
36863655 auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3687- loc, getTypeFromShape (prePermuteShape, inOptionalDType),
3656+ loc, getTensorTypeFromShapeValues (prePermuteShape, inOptionalDType),
36883657 partiallyExpanded, dimensionConstants[nLeadingDims + 1], factor);
36893658
36903659 // Perform the permutation
36913660 auto permuted = rewriter.create<AtenPermuteOp>(
3692- loc, getTypeFromShape (postPermuteShape, inOptionalDType), fullyExpanded ,
3693- permuteDimsOrder);
3661+ loc, getTensorTypeFromShapeValues (postPermuteShape, inOptionalDType),
3662+ fullyExpanded, permuteDimsOrder);
36943663
36953664 // Collapse final 2 dimension
36963665 auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
3697- loc, getTypeFromShape(partiallyCollapsedShape, inOptionalDType),
3666+ loc,
3667+ getTensorTypeFromShapeValues(partiallyCollapsedShape, inOptionalDType),
36983668 permuted, dimensionConstants[nLeadingDims + 3],
36993669 dimensionConstants[nLeadingDims + 4]);
37003670
@@ -3709,6 +3679,142 @@ class DecomposeAtenPixelShuffleOp
37093679};
37103680} // namespace
37113681
3682+ // Decompose aten.pixel_unshuffle into: prims.split_dim, aten.permute, and
3683+ // prims.collapse operations.
3684+ //
3685+ // We want to do the exact opposite of aten.pixel_shuffle
3686+ //
3687+ // 'r' is referred to as the 'downscale factor' or just 'factor' below.
3688+ //
3689+ // If input is a tensor of shape
3690+ // (*leading_dims, C, H*r, W*r),
3691+ //
3692+ // where leading_dims is of size N, then
3693+ // X = pixel_unshuffle(input, downscale_factor)
3694+ //
3695+ // gets replaced with
3696+ // X = input.split_dim(...) # shape (*leading_dims, C, H, r, W*r)
3697+ // X = X.split_dim(...) # shape (*leading_dims, C, H, r, W, r)
3698+ // X = X.permute(0, ..., N, N+2, N+4, N+1, N+3)
3699+ // # shape (*leading_dims, C, r, r, H, W)
3700+ // X = X.collapse(...) # shape (*leading_dims, C*r*r, H, W)
3701+ //
3702+ namespace {
3703+ class DecomposeAtenPixelUnshuffleOp
3704+ : public OpRewritePattern<AtenPixelUnshuffleOp> {
3705+ public:
3706+ using OpRewritePattern::OpRewritePattern;
3707+ LogicalResult matchAndRewrite(AtenPixelUnshuffleOp op,
3708+ PatternRewriter &rewriter) const override {
3709+
3710+ Location loc = op.getLoc();
3711+ Value inValue = op.getSelf();
3712+ auto inType = cast<BaseTensorType>(inValue.getType());
3713+ auto maybeSizes = inType.getOptionalSizes();
3714+ if (!maybeSizes) {
3715+ return rewriter.notifyMatchFailure(
3716+ op, "Expected input tensor to have known rank.");
3717+ }
3718+ auto inShape = maybeSizes.value();
3719+ auto inRank = inShape.size();
3720+
3721+ // The input tensor must have at least 3 dimensions: (1) the channel
3722+ // dimension which gets bigger by 'factor*factor', (2) the H dimension which
3723+ // gets smaller by 'factor' and (3) the W dimension which gets smaller by
3724+ // 'factor'. The total number of dimensions is 3 + N, where N is the number
3725+ // of leading dimensions, and N >= 0 so the input must have rank at least 3.
3726+ if (inRank < 3)
3727+ return rewriter.notifyMatchFailure(
3728+ op, "Expected input tensor to have rank greater than 2.");
3729+
3730+ const auto inOptionalDType = inType.getOptionalDtype();
3731+
3732+ auto nLeadingDims = inRank - 3; // number of dims preceding (C, H, W)
3733+
3734+ auto inC = getTensorDimSize(rewriter, inValue, inRank - 3);
3735+ auto inH = getTensorDimSize(rewriter, inValue, inRank - 2);
3736+ auto inW = getTensorDimSize(rewriter, inValue, inRank - 1);
3737+
3738+ auto factor = op.getDownscaleFactor(); // the downscale factor 'r'
3739+
3740+ Value factorSquared =
3741+ rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor); // r * r
3742+
3743+ Value outC = rewriter.createOrFold<AtenMulIntOp>(loc, inC, factorSquared); // C * r * r
3744+
3745+ Value outH = rewriter.createOrFold<AtenFloordivIntOp>(loc, inH, factor); // inH floordiv r
3746+ Value outW = rewriter.createOrFold<AtenFloordivIntOp>(loc, inW, factor); // inW floordiv r
3747+
3748+ SmallVector<Value> dimensionConstants;
3749+ dimensionConstants.reserve(inRank + 2); // after both splits the tensor has inRank + 2 dims
3750+ for (unsigned i = 0; i < inRank + 2; ++i) {
3751+ dimensionConstants.push_back(
3752+ rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
3753+ }
3754+
3755+ SmallVector<Value> leadingDims;
3756+ leadingDims.reserve(nLeadingDims);
3757+ for (unsigned i = 0; i < nLeadingDims; ++i) {
3758+ Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
3759+ loc, inValue, dimensionConstants[i]);
3760+ leadingDims.push_back(leadingDimSize);
3761+ }
3762+
3763+ SmallVector<Value> prePermuteShape = leadingDims;
3764+ prePermuteShape.append({inC, outH, factor, outW, factor});
3765+
3766+ SmallVector<Value> postPermuteShape = leadingDims;
3767+ postPermuteShape.append({inC, factor, factor, outH, outW});
3768+
3769+ SmallVector<Value> partiallyCollapsedShape = leadingDims; // NOTE(review): never read below
3770+ partiallyCollapsedShape.append({inC, factorSquared, outH, outW});
3771+
3772+ SmallVector<Value> outShape = leadingDims; // NOTE(review): never read below; the result type comes from op.getType()
3773+ outShape.append({outC, outH, outW});
3774+
3775+ SmallVector<Value> permutation{dimensionConstants.begin(),
3776+ dimensionConstants.begin() + nLeadingDims};
3777+ SmallVector<uint64_t> permutationTail{0, 2, 4, 1, 3}; // offsets from nLeadingDims: (C, H, r, W, r) -> (C, r, r, H, W)
3778+ for (uint64_t d : permutationTail) {
3779+ permutation.push_back(dimensionConstants[nLeadingDims + d]);
3780+ }
3781+
3782+ Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
3783+ loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
3784+ permutation);
3785+
3786+ SmallVector<Value> heightSplitShape = leadingDims; // shape after the first (height) split only
3787+ heightSplitShape.append({inC, outH, factor, inW});
3788+
3789+ // Split the input height dimension inH -> (outH, factor)
3790+ auto partiallyExpanded =
3791+ rewriter
3792+ .create<PrimsSplitDimOp>(
3793+ loc,
3794+ getTensorTypeFromShapeValues(heightSplitShape, inOptionalDType),
3795+ inValue, dimensionConstants[nLeadingDims + 1], outH)
3796+ .getResult();
3797+
3798+ // Split the width dimension inW -> (outW, factor); it sits at index N + 3 after the height split
3799+ auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
3800+ loc, getTensorTypeFromShapeValues(prePermuteShape, inOptionalDType),
3801+ partiallyExpanded, dimensionConstants[nLeadingDims + 3], outW);
3802+
3803+ // Perform the permutation: (C, H, r, W, r) -> (C, r, r, H, W)
3804+ auto permuted = rewriter.create<AtenPermuteOp>(
3805+ loc, getTensorTypeFromShapeValues(postPermuteShape, inOptionalDType),
3806+ fullyExpanded, permuteDimsOrder);
3807+
3808+ // Collapse dimensions N through N + 2, i.e. (C, r, r), back to the original rank
3809+ rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
3810+ op, op.getType(), permuted, dimensionConstants[nLeadingDims],
3811+ dimensionConstants[nLeadingDims + 2]);
3812+
3813+ return success();
3814+ }
3815+ };
3816+ } // namespace
3817+
37123818// Decompose aten.channel_shuffle into: prims.split_dim, aten.permute, and
37133819// prims.collapse operations.
37143820//
@@ -3763,23 +3869,14 @@ class DecomposeAtenChannelShuffleOp
37633869
37643870 auto numOfSpatialDims = inRank - 2;
37653871
3766- // Get the size of the dimension 'i'. Note the use of 'createOrFold'
3767- // instead of 'create': if the dimension size is known, then the
3768- // AtenSizeIntOp is folded to a ConstantOp.
3769- auto getDimSize = [&rewriter, &inValue, loc](uint64_t i) -> Value {
3770- Value dim =
3771- rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
3772- return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
3773- };
3774-
37753872 // The channel dimension is always the second dimension. PyTorch errors out
37763873 // if the batch dimension (first dimension) is not present. See comment at
37773874 // the top of this class for details.
3778- auto inC = getDimSize( 1);
3875+ auto inC = getTensorDimSize(rewriter, inValue, 1);
37793876 SmallVector<Value> inSpatialDims;
37803877 inSpatialDims.reserve(numOfSpatialDims);
37813878 for (unsigned i = 2; i < (2 + numOfSpatialDims); ++i) {
3782- inSpatialDims.push_back(getDimSize( i));
3879+ inSpatialDims.push_back(getTensorDimSize(rewriter, inValue, i));
37833880 }
37843881
37853882 auto groups = op.getGroups();
@@ -3832,14 +3929,14 @@ class DecomposeAtenChannelShuffleOp
38323929 auto expandedTensor =
38333930 rewriter
38343931 .create<PrimsSplitDimOp>(
3835- loc, getTypeFromShape (splitShape, inOptionalDType), inValue ,
3836- dimC, tempC)
3932+ loc, getTensorTypeFromShapeValues (splitShape, inOptionalDType),
3933+ inValue, dimC, tempC)
38373934 .getResult();
38383935
38393936 // Perform the permutation
38403937 auto permuted = rewriter.create<AtenPermuteOp>(
3841- loc, getTypeFromShape (permuteShape, inOptionalDType), expandedTensor ,
3842- permuteDimsOrder);
3938+ loc, getTensorTypeFromShapeValues (permuteShape, inOptionalDType),
3939+ expandedTensor, permuteDimsOrder);
38433940
38443941 // Collapse (C, groups) back into a single channel dimension
38453942 rewriter.replaceOpWithNewOp<PrimsCollapseOp>(op, op.getType(), permuted,
@@ -12909,6 +13006,7 @@ class DecomposeComplexOpsPass
1290913006 addPatternIfTargetOpIsIllegal<DecomposeAtenRenormOp>(patterns);
1291013007 addPatternIfTargetOpIsIllegal<DecomposeAtenLinalgCrossOp>(patterns);
1291113008 addPatternIfTargetOpIsIllegal<DecomposeAtenPixelShuffleOp>(patterns);
13009+ addPatternIfTargetOpIsIllegal<DecomposeAtenPixelUnshuffleOp>(patterns);
1291213010 addPatternIfTargetOpIsIllegal<DecomposeAtenChannelShuffleOp>(patterns);
1291313011 addPatternIfTargetOpIsIllegal<DecomposeAtenTOp>(patterns);
1291413012 addPatternIfTargetOpIsIllegal<DecomposeAten_LogSoftmaxBackwardDataOp>(
0 commit comments