From 1b7d6f2af91429c5a8e680faafaadd1b843cb0ca Mon Sep 17 00:00:00 2001
From: James Newling
Date: Wed, 22 Nov 2023 12:31:06 -0800
Subject: [PATCH] Improve decomposition of pixel_shuffle (support dynamic shapes) (#2590)

The aten.reshape ops in the decomposition are replaced with prims.collapse
and prims.split_dim ops, which means that the cases where the lowering of
reshape from torch to linalg is not supported are avoided. Essentially, by
using the collapse and split_dim ops instead of the reshape ops, we are not
"losing" the information that the reshapes do not arbitrarily mix
dimensions, which makes the lowering easy.

3 additional tests are added:
- fully dynamic,
- dynamic only in the spatial dimensions,
- dynamic only in the non-spatial dimensions.
---
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 216 +++++++++---------
 projects/pt1/e2e_testing/xfail_sets.py        |   5 +-
 .../torch_mlir_e2e_test/test_suite/basic.py   |  49 ++++
 3 files changed, 159 insertions(+), 111 deletions(-)

diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index b19d3f949..a04e4da13 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1095,16 +1095,22 @@ public:
 };
 } // namespace
 
-// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations.
+// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
+// prims.collapse operations.
 //
-// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where
-// leading_dims is of size N, then
+// If input is a tensor of shape
+//    (*leading_dims, C*r*r, H, W),
+//
+// where leading_dims is of size N, then
 //    X = pixel_shuffle(input, upscale_factor)
 //
 // gets replaced with
-//    A = input.reshape(*leading_dims, C, r, r, H, W)
-//    B = A.permute(0, ..., N, N+3, N+1, N+4, N+2)
-//    X = B.reshape(*leading_dims, C, r*H, r*W)
+//    X = input.split_dim(...)    # shape (*leading_dims, C, r*r, H, W)
+//    X = X.split_dim(...)        # shape (*leading_dims, C, r, r, H, W)
+//    X = X.permute(0, ..., N, N+3, N+1, N+4, N+2)
+//                                # shape (*leading_dims, C, H, r, W, r)
+//    X = X.collapse(...)         # shape (*leading_dims, C, r, H, r*W)
+//    X = X.collapse(...)         # shape (*leading_dims, C, r*H, r*W)
 //
 // 'r' above is referred to as the 'upscale factor' or just 'factor' below.
 namespace {
@@ -1115,7 +1121,6 @@ public:
 
   LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
                                 PatternRewriter &rewriter) const override {
-
     Location loc = op.getLoc();
     Value inValue = op.getSelf();
     auto inType = inValue.getType().cast<BaseTensorType>();
@@ -1127,22 +1132,6 @@ public:
     auto inShape = maybeSizes.value();
     auto inRank = inShape.size();
 
-    // TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg
-    // directly. Pixel shuffle does a reshape that is hard to recover
-    // through pure torch (view) ops, especially in dynamic cases.
-    //
-    // See: https://github.com/llvm/torch-mlir/issues/2559
-    //
-    // For now, we just fail the decomposition here so that a sensible error is
-    // provided:
-    for (auto dimSize : inShape) {
-      if (dimSize == kUnknownSize) {
-        return rewriter.notifyMatchFailure(
-            op, "Currently we only decompose pixel_shuffle if the input tensor "
-                "is statically shaped");
-      }
-    }
-
     // The input tensor must have at least 3 dimensions: (1) the channel
     // dimension which gets smaller by 'factor*factor', (2) the H channel which
     // gets larger by 'factor' and (3) the W channel which get larger by
@@ -1152,58 +1141,6 @@ public:
       return rewriter.notifyMatchFailure(
           op, "Expected input tensor to have rank greater than 2.");
 
-    auto nLeadingDims = inRank - 3;
-
-    // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
-    // of 'create': if the dimension size is known, then the AtenSizeIntOp is
-    // folded to a ConstantOp.
-    auto getDimSize = [&](uint64_t i) -> Value {
-      Value dim =
-          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
-      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
-    };
-
-    auto inC = getDimSize(inRank - 3);
-    auto inH = getDimSize(inRank - 2);
-    auto inW = getDimSize(inRank - 1);
-
-    auto factor = op.getUpscaleFactor();
-
-
-    Value factorSquared =
-        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
-    Value outC =
-        rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);
-
-    Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
-    Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);
-
-    // Shape of 'A' in the comment at the top
-    SmallVector<Value> prePermuteShape;
-    prePermuteShape.reserve(nLeadingDims + 5);
-
-    // Shape of 'B' in the comment at the top.
-    SmallVector<Value> postPermuteShape;
-    postPermuteShape.reserve(nLeadingDims + 5);
-
-    SmallVector<Value> outShape;
-    outShape.reserve(nLeadingDims + 3);
-
-    SmallVector<Value> permutation;
-    permutation.reserve(nLeadingDims + 5);
-
-    for (unsigned i = 0; i < nLeadingDims; ++i) {
-      auto dimensionAttr = rewriter.getI64IntegerAttr(i);
-      Value dimensionValue = rewriter.create<Torch::ConstantIntOp>(loc, dimensionAttr);
-      Value leadingDimSize =
-          rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimensionValue);
-      prePermuteShape.push_back(leadingDimSize);
-      postPermuteShape.push_back(leadingDimSize);
-      outShape.push_back(leadingDimSize);
-      permutation.push_back(dimensionValue);
-
-    }
-
     const auto inOptionalDType = inType.getOptionalDtype();
 
     auto getTypeFromShape = [inOptionalDType](auto &&vals) {
@@ -1227,48 +1164,111 @@ public:
           llvm::ArrayRef(intShape), inOptionalDType);
     };
 
-    prePermuteShape.insert(prePermuteShape.end(),
-                           {outC, factor, factor, inH, inW});
+    auto nLeadingDims = inRank - 3;
 
-    postPermuteShape.insert(postPermuteShape.end(),
-                            {outC, inH, factor, inW, factor});
+    // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
+    // of 'create': if the dimension size is known, then the AtenSizeIntOp is
+    // folded to a ConstantOp.
+    auto getDimSize = [&](uint64_t i) -> Value {
+      Value dim =
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
+      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
+    };
 
-    outShape.insert(outShape.end(), {outC, outH, outW});
+    auto inC = getDimSize(inRank - 3);
+    auto inH = getDimSize(inRank - 2);
+    auto inW = getDimSize(inRank - 1);
 
-    SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
-    for (uint64_t d : permutationTail) {
-      permutation.push_back(rewriter.create<Torch::ConstantIntOp>(
-          loc, rewriter.getI64IntegerAttr(nLeadingDims + d)));
+    auto factor = op.getUpscaleFactor();
+
+    Value factorSquared =
+        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
+
+    Value outC =
+        rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);
+
+    Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
+    Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);
+
+    SmallVector<Value> dimensionConstants;
+    dimensionConstants.reserve(inRank + 2);
+    for (unsigned i = 0; i < inRank + 2; ++i) {
+      dimensionConstants.push_back(
+          rewriter.create<Torch::ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
     }
-    auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
 
+    SmallVector<Value> leadingDims;
+    leadingDims.reserve(nLeadingDims);
+    for (unsigned i = 0; i < nLeadingDims; ++i) {
+      Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
+          loc, inValue, dimensionConstants[i]);
+      leadingDims.push_back(leadingDimSize);
+    }
 
-    Value shapeA =
-        rewriter.create<PrimListConstructOp>(loc, listType, prePermuteShape);
+    SmallVector<Value> partiallyExpandedShape = leadingDims;
+    partiallyExpandedShape.append({outC, factorSquared, inH, inW});
 
-    Value A = rewriter.create<AtenReshapeOp>(
-        loc, getTypeFromShape(prePermuteShape), inValue, shapeA);
+    SmallVector<Value> prePermuteShape = leadingDims;
+    prePermuteShape.append({outC, factor, factor, inH, inW});
+
+    SmallVector<Value> postPermuteShape = leadingDims;
+    postPermuteShape.append({outC, inH, factor, inW, factor});
+
+    SmallVector<Value> partiallyCollapsedShape = leadingDims;
+    partiallyCollapsedShape.append({outC, inH, factor, outW});
+
+    SmallVector<Value> outShape = leadingDims;
+    outShape.append({outC, outH, outW});
+
+    SmallVector<Value> permutation{dimensionConstants.begin(),
+                                   dimensionConstants.begin() + nLeadingDims};
+    SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
+    for (uint64_t d : permutationTail) {
+      permutation.push_back(dimensionConstants[nLeadingDims + d]);
+    }
 
     Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
         loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
         permutation);
 
-    Value B = rewriter.create<AtenPermuteOp>(
-        loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder);
+    // Split input channel inC -> (inC, factorSquared)
+    auto partiallyExpanded =
+        rewriter
+            .create<PrimsSplitDimOp>(
+                loc, getTypeFromShape(partiallyExpandedShape), inValue,
+                dimensionConstants[nLeadingDims], outC)
+            .getResult();
 
-    Value outShapeList =
-        rewriter.create<PrimListConstructOp>(loc, listType, outShape);
+    // Split new dimension factorSquared -> (factor, factor)
+    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
+        loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
+        dimensionConstants[nLeadingDims + 1], factor);
+
+    // Perform the permutation
+    auto permuted =
+        rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
+                                       fullyExpanded, permuteDimsOrder);
+
+    // Collapse final 2 dimension
+    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
+        loc, getTypeFromShape(partiallyCollapsedShape), permuted,
+        dimensionConstants[nLeadingDims + 3],
+        dimensionConstants[nLeadingDims + 4]);
+
+    // Collapse back to original rank
+    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
+        op, op.getType(), partiallyCollapsed,
+        dimensionConstants[nLeadingDims + 1],
+        dimensionConstants[nLeadingDims + 2]);
 
-
-    rewriter.replaceOpWithNewOp<AtenReshapeOp>(op, op.getType(), B,
-                                               outShapeList);
     return success();
   }
 };
 } // namespace
 
 // ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
-static Value
-getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) {
+static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
+                             Value input) {
   BaseTensorType inputType = input.getType().cast<BaseTensorType>();
 
   Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
@@ -1815,7 +1815,7 @@ public:
     auto inputTensorType = self.getType().cast<BaseTensorType>();
     if (!inputTensorType || !inputTensorType.hasSizes()) {
       return rewriter.notifyMatchFailure(op,
-                                          "Expected input type having sizes");
+                                         "Expected input type having sizes");
     }
     ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
@@ -1851,7 +1851,7 @@ public:
       Value dimSize = rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dimValue);
       if (i == dimInt) {
-        int64_t inferredSizeInt = inputShape[i];
+        int64_t inferredSizeInt = inputShape[i];
         int64_t inferredDim;
         for (unsigned j = 0; j < sizesInts.size(); ++j) {
           if (sizesInts[j] == -1) {
@@ -1865,11 +1865,9 @@ public:
         }
       }
       if (inferred) {
-        Value inferredSize =
-            rewriter.create<Torch::ConstantIntOp>(
+        Value inferredSize = rewriter.create<Torch::ConstantIntOp>(
             loc, rewriter.getI64IntegerAttr(inferredSizeInt));
-        newSizes.insert(
-            newSizes.begin() + inferredDim + i, inferredSize);
+        newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize);
       }
     } else {
       newSizes.push_back(dimSize);
@@ -4095,7 +4093,7 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
 } // namespace
 
 namespace {
-class DecomposeAtenCosineSimilarityOp
+class DecomposeAtenCosineSimilarityOp
     : public OpRewritePattern<AtenCosineSimilarityOp> {
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
@@ -4122,7 +4120,7 @@ class DecomposeAtenCosineSimilarityOp
         indexBroadcastShapeTorchList);
 
     // Compute the mul of A and B
-    Value dotProduct =
+    Value dotProduct =
         rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
     Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
     Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -4133,17 +4131,17 @@ class DecomposeAtenCosineSimilarityOp
         loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
         /*keepdim=*/cstFalse, /*dtype=*/cstNone);
-
+
     // Compute the norm of A and B
-    Value ord = rewriter.create<Torch::ConstantFloatOp>(loc,
-                                rewriter.getF64FloatAttr(2.0));
+    Value ord = rewriter.create<Torch::ConstantFloatOp>(
+        loc, rewriter.getF64FloatAttr(2.0));
     Value normA = rewriter.create<AtenLinalgVectorNormOp>(
         loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
         /*dtype=*/cstNone);
     Value normB = rewriter.create<AtenLinalgVectorNormOp>(
        loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
        /*dtype=*/cstNone);
-
+
     // Compute the product of the norms
     Value normProduct = rewriter.create<AtenMulTensorOp>(loc, op.getType(), normA, normB);
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 595afbb74..eeaad8690 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -948,8 +948,6 @@ STABLEHLO_CRASHING_SET = {
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
-    "PixelShuffleModuleStaticRank3Int64_basic",
-    "PixelShuffleModuleStaticRank4Float32_basic",
     "IscloseStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
     "TileBigDimsSizeModule_basic",
@@ -1371,6 +1369,9 @@ LTC_XFAIL_SET = {
     "SplitDimDynamicModule_basic",
     "PixelShuffleModuleStaticRank3Int64_basic",
     "PixelShuffleModuleStaticRank4Float32_basic",
+    "PixelShuffleModuleFullDynamic_basic",
+    "PixelShuffleModuleSpatiallyDynamic_basic",
+    "PixelShuffleModuleSpatiallyStatic_basic",
     "_Convolution2DAllFalseModule_basic",
     "_Convolution2DBenchmarkModule_basic",
     "_Convolution2DCudnnModule_basic",
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index 9eb1a8986..a9d9e2e9c 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -686,8 +686,57 @@ class PixelShuffleModuleStaticRank3Int64(torch.nn.Module):
 def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils):
     module.forward(tu.randint(12, 2, 3, low = 0, high = 100))
 
+# ==============================================================================
+class PixelShuffleModuleFullDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1,-1,-1,-1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_shuffle(x, 2)
+
+@register_test_case(module_factory=lambda: PixelShuffleModuleFullDynamic())
+def PixelShuffleModuleFullDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1,8,3,3, low = 0, high = 100))
+
+# ==============================================================================
+
+
+class PixelShuffleModuleSpatiallyDynamic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([2,1,8,-1,-1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_shuffle(x, 2)
+
+@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyDynamic())
+def PixelShuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(2,1,8,2,3, low = 0, high = 100))
+
+
+# ==============================================================================
+
+class PixelShuffleModuleSpatiallyStatic(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([None, ([-1,-1,-1,3,1], torch.int64, True)])
+    def forward(self, x):
+        return torch.ops.aten.pixel_shuffle(x, 2)
+
+@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyStatic())
+def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
+    module.forward(tu.randint(1,2,12,3,1, low = 0, high = 100))
+
+
+# ==============================================================================
+
 class TensorsConcatModule(torch.nn.Module):
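
Reviewer note (not part of the patch): the split_dim / permute / collapse sequence described in the comment at the top of DecomposeAtenPixelShuffleOp can be sanity-checked against PyTorch's own pixel_shuffle with a few lines of plain Python. This is an illustrative sketch only: the helper name pixel_shuffle_decomposed is ours, and it uses ordinary reshape/permute calls in place of the prims.split_dim, aten.permute, and prims.collapse ops that the pattern actually emits.

    import torch

    def pixel_shuffle_decomposed(x: torch.Tensor, factor: int) -> torch.Tensor:
        # x has shape (*leading_dims, C*r*r, H, W), with r == factor.
        *leading, c_in, h, w = x.shape
        c_out = c_in // (factor * factor)
        n = len(leading)
        # Two split_dims: (..., C*r*r, H, W) -> (..., C, r, r, H, W)
        x = x.reshape(*leading, c_out, factor, factor, h, w)
        # permute(0, ..., N, N+3, N+1, N+4, N+2): -> (..., C, H, r, W, r)
        x = x.permute(*range(n), n, n + 3, n + 1, n + 4, n + 2)
        # Two collapses: -> (..., C, r*H, r*W)
        return x.reshape(*leading, c_out, h * factor, w * factor)

    # Shapes mirror the new PixelShuffleModuleSpatiallyDynamic e2e test.
    x = torch.randint(0, 100, (2, 1, 8, 2, 3))
    assert torch.equal(pixel_shuffle_decomposed(x, 2),
                       torch.ops.aten.pixel_shuffle(x, 2))

The assert passing confirms that the split/permute/collapse ordering used by the decomposition reproduces aten.pixel_shuffle on the kind of shapes the new dynamic-shape tests exercise.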