Improve decomposition of pixel_shuffle (support dynamic shapes) (#2590)

The aten.reshape ops in the decomposition are replaced with prims.collapse
and prims.split_dim ops, which avoids the cases where the torch-to-linalg
lowering of reshape is not supported.

Essentially, by using the collapse and split_dim ops instead of reshape ops,
we do not "lose" the information that the reshapes never arbitrarily mix
dimensions, which makes the lowering straightforward.
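
For intuition, here is a minimal Python sketch of the decomposition (the helper name and shapes are illustrative, not part of this change; `unflatten`/`flatten` stand in for `prims.split_dim`/`prims.collapse`, which have the same structured-reshape semantics), checked against `torch.nn.functional.pixel_shuffle`:

```python
import torch

# Hypothetical reference helper: mirrors the
# split_dim -> split_dim -> permute -> collapse -> collapse decomposition.
def pixel_shuffle_decomposed(x: torch.Tensor, r: int) -> torch.Tensor:
    n = x.dim() - 3                 # number of leading dims
    c = x.shape[-3] // (r * r)      # output channel count
    x = x.unflatten(n, (c, r * r))  # split_dim: (..., C*r*r, H, W) -> (..., C, r*r, H, W)
    x = x.unflatten(n + 1, (r, r))  # split_dim: (..., C, r*r, H, W) -> (..., C, r, r, H, W)
    perm = list(range(n)) + [n, n + 3, n + 1, n + 4, n + 2]
    x = x.permute(perm)             # (..., C, H, r, W, r)
    x = x.flatten(n + 3, n + 4)     # collapse: (..., C, H, r, r*W)
    x = x.flatten(n + 1, n + 2)     # collapse: (..., C, r*H, r*W)
    return x

t = torch.arange(2 * 8 * 3 * 3).reshape(2, 8, 3, 3)
assert torch.equal(pixel_shuffle_decomposed(t, 2),
                   torch.nn.functional.pixel_shuffle(t, 2))
```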

Three additional tests are added:
- fully dynamic,
- dynamic only in the spatial dimensions,
- dynamic only in the non-spatial dimensions.
James Newling 2023-11-22 12:31:06 -08:00 committed by GitHub
parent e06efc5136
commit 1b7d6f2af9
3 changed files with 159 additions and 111 deletions


@@ -1095,16 +1095,22 @@ public:
};
} // namespace
// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations.
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where
// leading_dims is of size N, then
// If input is a tensor of shape
// (*leading_dims, C*r*r, H, W),
//
// where leading_dims is of size N, then
// X = pixel_shuffle(input, upscale_factor)
//
// gets replaced with
// A = input.reshape(*leading_dims, C, r, r, H, W)
// B = A.permute(0, ..., N, N+3, N+1, N+4, N+2)
// X = B.reshape(*leading_dims, C, r*H, r*W)
// X = input.split_dim(...) # shape (*leading_dims, C, r*r, H, W)
// X = X.split_dim(...) # shape (*leading_dims, C, r, r, H, W)
// X = X.permute(0, ..., N, N+3, N+1, N+4, N+2)
// # shape (*leading_dims, C, H, r, W, r)
// X = X.collapse(...) # shape (*leading_dims, C, H, r, r*W)
// X = X.collapse(...) # shape (*leading_dims, C, r*H, r*W)
//
// 'r' above is referred to as the 'upscale factor' or just 'factor' below.
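//
// As an illustrative example (not taken from the tests): with an input of
// shape (2, 8, 3, 3) and upscale_factor r = 2, the intermediate shapes are
// (2, 8, 3, 3) -> (2, 2, 4, 3, 3) -> (2, 2, 2, 2, 3, 3) -> (2, 2, 3, 2, 3, 2)
// -> (2, 2, 3, 2, 6) -> (2, 2, 6, 6).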
namespace {
@@ -1115,7 +1121,6 @@ public:
LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value inValue = op.getSelf();
auto inType = inValue.getType().cast<BaseTensorType>();
@@ -1127,22 +1132,6 @@ public:
auto inShape = maybeSizes.value();
auto inRank = inShape.size();
// TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg
// directly. Pixel shuffle does a reshape that is hard to recover
// through pure torch (view) ops, especially in dynamic cases.
//
// See: https://github.com/llvm/torch-mlir/issues/2559
//
// For now, we just fail the decomposition here so that a sensible error is
// provided:
for (auto dimSize : inShape) {
if (dimSize == kUnknownSize) {
return rewriter.notifyMatchFailure(
op, "Currently we only decompose pixel_shuffle if the input tensor "
"is statically shaped");
}
}
// The input tensor must have at least 3 dimensions: (1) the channel
// dimension which gets smaller by 'factor*factor', (2) the H channel which
// gets larger by 'factor' and (3) the W channel which gets larger by
@@ -1152,58 +1141,6 @@ public:
return rewriter.notifyMatchFailure(
op, "Expected input tensor to have rank greater than 2.");
auto nLeadingDims = inRank - 3;
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};
auto inC = getDimSize(inRank - 3);
auto inH = getDimSize(inRank - 2);
auto inW = getDimSize(inRank - 1);
auto factor = op.getUpscaleFactor();
Value factorSquared =
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
Value outC =
rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);
Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);
// Shape of 'A' in the comment at the top
SmallVector<Value> prePermuteShape;
prePermuteShape.reserve(nLeadingDims + 5);
// Shape of 'B' in the comment at the top.
SmallVector<Value> postPermuteShape;
postPermuteShape.reserve(nLeadingDims + 5);
SmallVector<Value> outShape;
outShape.reserve(nLeadingDims + 3);
SmallVector<Value> permutation;
permutation.reserve(nLeadingDims + 5);
for (unsigned i = 0; i < nLeadingDims; ++i) {
auto dimensionAttr = rewriter.getI64IntegerAttr(i);
Value dimensionValue = rewriter.create<ConstantIntOp>(loc, dimensionAttr);
Value leadingDimSize =
rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimensionValue);
prePermuteShape.push_back(leadingDimSize);
postPermuteShape.push_back(leadingDimSize);
outShape.push_back(leadingDimSize);
permutation.push_back(dimensionValue);
}
const auto inOptionalDType = inType.getOptionalDtype();
auto getTypeFromShape = [inOptionalDType](auto &&vals) {
@@ -1227,48 +1164,111 @@ public:
llvm::ArrayRef(intShape), inOptionalDType);
};
prePermuteShape.insert(prePermuteShape.end(),
{outC, factor, factor, inH, inW});
auto nLeadingDims = inRank - 3;
postPermuteShape.insert(postPermuteShape.end(),
{outC, inH, factor, inW, factor});
// Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
// of 'create': if the dimension size is known, then the AtenSizeIntOp is
// folded to a ConstantOp.
auto getDimSize = [&](uint64_t i) -> Value {
Value dim =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
};
outShape.insert(outShape.end(), {outC, outH, outW});
auto inC = getDimSize(inRank - 3);
auto inH = getDimSize(inRank - 2);
auto inW = getDimSize(inRank - 1);
SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
for (uint64_t d : permutationTail) {
permutation.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(nLeadingDims + d)));
auto factor = op.getUpscaleFactor();
Value factorSquared =
rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
Value outC =
rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);
Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);
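// Pre-create a constant index Value for every dimension of the fully
// expanded tensor (rank inRank + 2); these are reused below for reading
// sizes, for the split_dim/collapse dimension arguments, and for the
// permutation.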
SmallVector<Value> dimensionConstants;
dimensionConstants.reserve(inRank + 2);
for (unsigned i = 0; i < inRank + 2; ++i) {
dimensionConstants.push_back(
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
}
auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
SmallVector<Value> leadingDims;
leadingDims.reserve(nLeadingDims);
for (unsigned i = 0; i < nLeadingDims; ++i) {
Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
loc, inValue, dimensionConstants[i]);
leadingDims.push_back(leadingDimSize);
}
Value shapeA =
rewriter.create<PrimListConstructOp>(loc, listType, prePermuteShape);
SmallVector<Value> partiallyExpandedShape = leadingDims;
partiallyExpandedShape.append({outC, factorSquared, inH, inW});
Value A = rewriter.create<AtenReshapeOp>(
loc, getTypeFromShape(prePermuteShape), inValue, shapeA);
SmallVector<Value> prePermuteShape = leadingDims;
prePermuteShape.append({outC, factor, factor, inH, inW});
SmallVector<Value> postPermuteShape = leadingDims;
postPermuteShape.append({outC, inH, factor, inW, factor});
SmallVector<Value> partiallyCollapsedShape = leadingDims;
partiallyCollapsedShape.append({outC, inH, factor, outW});
SmallVector<Value> outShape = leadingDims;
outShape.append({outC, outH, outW});
SmallVector<Value> permutation{dimensionConstants.begin(),
dimensionConstants.begin() + nLeadingDims};
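// Relative to the first non-leading dimension, the last five dims
// (C, r, r, H, W) are reordered to (C, H, r, W, r), hence the tail
// {0, 3, 1, 4, 2}.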
SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
for (uint64_t d : permutationTail) {
permutation.push_back(dimensionConstants[nLeadingDims + d]);
}
Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
permutation);
Value B = rewriter.create<AtenPermuteOp>(
loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder);
// Split the input channel dimension: inC -> (outC, factorSquared)
auto partiallyExpanded =
rewriter
.create<PrimsSplitDimOp>(
loc, getTypeFromShape(partiallyExpandedShape), inValue,
dimensionConstants[nLeadingDims], outC)
.getResult();
Value outShapeList =
rewriter.create<PrimListConstructOp>(loc, listType, outShape);
// Split new dimension factorSquared -> (factor, factor)
auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
dimensionConstants[nLeadingDims + 1], factor);
// Perform the permutation
auto permuted =
rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
fullyExpanded, permuteDimsOrder);
// Collapse the final 2 dimensions
auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
loc, getTypeFromShape(partiallyCollapsedShape), permuted,
dimensionConstants[nLeadingDims + 3],
dimensionConstants[nLeadingDims + 4]);
// Collapse back to original rank
rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
op, op.getType(), partiallyCollapsed,
dimensionConstants[nLeadingDims + 1],
dimensionConstants[nLeadingDims + 2]);
rewriter.replaceOpWithNewOp<AtenReshapeOp>(op, op.getType(), B,
outShapeList);
return success();
}
};
} // namespace
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value
getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) {
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
Value input) {
BaseTensorType inputType = input.getType().cast<BaseTensorType>();
Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
@@ -1815,7 +1815,7 @@ public:
auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
if (!inputTensorType || !inputTensorType.hasSizes()) {
return rewriter.notifyMatchFailure(op,
"Expected input type having sizes");
"Expected input type having sizes");
}
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
@@ -1851,7 +1851,7 @@ public:
Value dimSize =
rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dimValue);
if (i == dimInt) {
int64_t inferredSizeInt = inputShape[i];
int64_t inferredSizeInt = inputShape[i];
int64_t inferredDim;
for (unsigned j = 0; j < sizesInts.size(); ++j) {
if (sizesInts[j] == -1) {
@@ -1865,11 +1865,9 @@ public:
}
}
if (inferred) {
Value inferredSize =
rewriter.create<ConstantIntOp>(
Value inferredSize = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inferredSizeInt));
newSizes.insert(
newSizes.begin() + inferredDim + i, inferredSize);
newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize);
}
} else {
newSizes.push_back(dimSize);
@@ -4095,7 +4093,7 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
} // namespace
namespace {
class DecomposeAtenCosineSimilarityOp
class DecomposeAtenCosineSimilarityOp
: public OpRewritePattern<AtenCosineSimilarityOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
@@ -4122,7 +4120,7 @@ class DecomposeAtenCosineSimilarityOp
indexBroadcastShapeTorchList);
// Compute the mul of A and B
Value dotProduct =
Value dotProduct =
rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -4133,17 +4131,17 @@ class DecomposeAtenCosineSimilarityOp
loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
/*keepdim=*/cstFalse,
/*dtype=*/cstNone);
// Compute the norm of A and B
Value ord = rewriter.create<Torch::ConstantFloatOp>(loc,
rewriter.getF64FloatAttr(2.0));
Value ord = rewriter.create<Torch::ConstantFloatOp>(
loc, rewriter.getF64FloatAttr(2.0));
Value normA = rewriter.create<AtenLinalgVectorNormOp>(
loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
/*dtype=*/cstNone);
Value normB = rewriter.create<AtenLinalgVectorNormOp>(
loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
/*dtype=*/cstNone);
// Compute the product of the norms
Value normProduct =
rewriter.create<AtenMulTensorOp>(loc, op.getType(), normA, normB);


@@ -948,8 +948,6 @@ STABLEHLO_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic",
@@ -1371,6 +1369,9 @@ LTC_XFAIL_SET = {
"SplitDimDynamicModule_basic",
"PixelShuffleModuleStaticRank3Int64_basic",
"PixelShuffleModuleStaticRank4Float32_basic",
"PixelShuffleModuleFullDynamic_basic",
"PixelShuffleModuleSpatiallyDynamic_basic",
"PixelShuffleModuleSpatiallyStatic_basic",
"_Convolution2DAllFalseModule_basic",
"_Convolution2DBenchmarkModule_basic",
"_Convolution2DCudnnModule_basic",


@@ -686,8 +686,57 @@ class PixelShuffleModuleStaticRank3Int64(torch.nn.Module):
def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils):
    module.forward(tu.randint(12, 2, 3, low = 0, high = 100))
# ==============================================================================
class PixelShuffleModuleFullDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1,-1,-1,-1], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 2)


@register_test_case(module_factory=lambda: PixelShuffleModuleFullDynamic())
def PixelShuffleModuleFullDynamic_basic(module, tu: TestUtils):
    module.forward(tu.randint(1,8,3,3, low = 0, high = 100))
# ==============================================================================
class PixelShuffleModuleSpatiallyDynamic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([2,1,8,-1,-1], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 2)


@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyDynamic())
def PixelShuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
    module.forward(tu.randint(2,1,8,2,3, low = 0, high = 100))
# ==============================================================================
class PixelShuffleModuleSpatiallyStatic(torch.nn.Module):

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1,-1,-1,3,1], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 2)


@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyStatic())
def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
    module.forward(tu.randint(1,2,12,3,1, low = 0, high = 100))
# ==============================================================================
class TensorsConcatModule(torch.nn.Module):