mirror of https://github.com/llvm/torch-mlir
Improve decomposition of pixel_shuffle (support dynamic shapes) (#2590)
The aten.reshape ops in the decomposition are replaced with prims.collapse and prims.split_dim ops, which avoids the cases where the lowering of reshape from torch to linalg is not supported. Essentially, by using the collapse and split_dim ops instead of the reshape ops, we do not "lose" the information that the reshapes never arbitrarily mix dimensions, which makes the lowering easy. 3 additional tests added: fully dynamic, dynamic only in the spatial dimensions, dynamic only in the non-spatial dimensions.

pull/2595/head
parent e06efc5136
commit 1b7d6f2af9
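As a quick reference for the decomposition described above, the following is a minimal PyTorch sketch, not part of the diff, of the same split -> permute -> collapse sequence; the helper name and example shapes are made up, and plain reshape stands in for prims.split_dim / prims.collapse:

import torch

def pixel_shuffle_by_split_permute_collapse(x, r):
    # x has shape (*leading_dims, C*r*r, H, W)
    *leading, crr, h, w = x.shape
    c = crr // (r * r)
    n = len(leading)
    # split the channel dimension: (..., C*r*r, H, W) -> (..., C, r, r, H, W)
    a = x.reshape(*leading, c, r, r, h, w)
    # permute to (..., C, H, r, W, r)
    b = a.permute(*range(n), n, n + 3, n + 1, n + 4, n + 2)
    # collapse (W, r) and then (H, r): -> (..., C, r*H, r*W)
    return b.reshape(*leading, c, r * h, r * w)

x = torch.arange(2 * 8 * 3 * 3).reshape(2, 8, 3, 3)
assert torch.equal(pixel_shuffle_by_split_permute_collapse(x, 2),
                   torch.pixel_shuffle(x, 2))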
@@ -1095,16 +1095,22 @@ public:
};
} // namespace

// Decompose aten.pixel_shuffle into: aten.permute and aten.reshape operations.
// Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and
// prims.collapse operations.
//
// If input is a tensor of shape (*leading_dims, C*r*r, H, W), where
// leading_dims is of size N, then
// If input is a tensor of shape
//    (*leading_dims, C*r*r, H, W),
//
// where leading_dims is of size N, then
//    X = pixel_shuffle(input, upscale_factor)
//
// gets replaced with
//    A = input.reshape(*leading_dims, C, r, r, H, W)
//    B = A.permute(0, ..., N, N+3, N+1, N+4, N+2)
//    X = B.reshape(*leading_dims, C, r*H, r*W)
//    X = input.split_dim(...)  # shape (*leading_dims, C, r*r, H, W)
//    X = X.split_dim(...)      # shape (*leading_dims, C, r, r, H, W)
//    X = X.permute(0, ..., N, N+3, N+1, N+4, N+2)
//                              # shape (*leading_dims, C, H, r, W, r)
//    X = X.collapse(...)       # shape (*leading_dims, C, H, r, r*W)
//    X = X.collapse(...)       # shape (*leading_dims, C, r*H, r*W)
//
// 'r' above is referred to as the 'upscale factor' or just 'factor' below.
namespace {
@@ -1115,7 +1121,6 @@ public:
  LogicalResult matchAndRewrite(AtenPixelShuffleOp op,
                                PatternRewriter &rewriter) const override {

    Location loc = op.getLoc();
    Value inValue = op.getSelf();
    auto inType = inValue.getType().cast<BaseTensorType>();
@@ -1127,22 +1132,6 @@ public:
    auto inShape = maybeSizes.value();
    auto inRank = inShape.size();

    // TODO support dynamic shapes, probably by lowering pixel_shuffle to linalg
    // directly. Pixel shuffle does a reshape that is hard to recover
    // through pure torch (view) ops, especially in dynamic cases.
    //
    // See: https://github.com/llvm/torch-mlir/issues/2559
    //
    // For now, we just fail the decomposition here so that a sensible error is
    // provided:
    for (auto dimSize : inShape) {
      if (dimSize == kUnknownSize) {
        return rewriter.notifyMatchFailure(
            op, "Currently we only decompose pixel_shuffle if the input tensor "
                "is statically shaped");
      }
    }

    // The input tensor must have at least 3 dimensions: (1) the channel
    // dimension which gets smaller by 'factor*factor', (2) the H channel which
    // gets larger by 'factor' and (3) the W channel which get larger by
@@ -1152,58 +1141,6 @@ public:
      return rewriter.notifyMatchFailure(
          op, "Expected input tensor to have rank greater than 2.");

    auto nLeadingDims = inRank - 3;

    // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
    // of 'create': if the dimension size is known, then the AtenSizeIntOp is
    // folded to a ConstantOp.
    auto getDimSize = [&](uint64_t i) -> Value {
      Value dim =
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
    };

    auto inC = getDimSize(inRank - 3);
    auto inH = getDimSize(inRank - 2);
    auto inW = getDimSize(inRank - 1);

    auto factor = op.getUpscaleFactor();

    Value factorSquared =
        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);
    Value outC =
        rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);

    Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
    Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);

    // Shape of 'A' in the comment at the top
    SmallVector<Value> prePermuteShape;
    prePermuteShape.reserve(nLeadingDims + 5);

    // Shape of 'B' in the comment at the top.
    SmallVector<Value> postPermuteShape;
    postPermuteShape.reserve(nLeadingDims + 5);

    SmallVector<Value> outShape;
    outShape.reserve(nLeadingDims + 3);

    SmallVector<Value> permutation;
    permutation.reserve(nLeadingDims + 5);

    for (unsigned i = 0; i < nLeadingDims; ++i) {
      auto dimensionAttr = rewriter.getI64IntegerAttr(i);
      Value dimensionValue = rewriter.create<ConstantIntOp>(loc, dimensionAttr);
      Value leadingDimSize =
          rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dimensionValue);
      prePermuteShape.push_back(leadingDimSize);
      postPermuteShape.push_back(leadingDimSize);
      outShape.push_back(leadingDimSize);
      permutation.push_back(dimensionValue);
    }

    const auto inOptionalDType = inType.getOptionalDtype();

    auto getTypeFromShape = [inOptionalDType](auto &&vals) {
@@ -1227,48 +1164,111 @@ public:
          llvm::ArrayRef(intShape), inOptionalDType);
    };

    prePermuteShape.insert(prePermuteShape.end(),
                           {outC, factor, factor, inH, inW});
    auto nLeadingDims = inRank - 3;

    postPermuteShape.insert(postPermuteShape.end(),
                            {outC, inH, factor, inW, factor});
    // Get the size of the dimension 'i'. Note the use of 'createOrFold' instead
    // of 'create': if the dimension size is known, then the AtenSizeIntOp is
    // folded to a ConstantOp.
    auto getDimSize = [&](uint64_t i) -> Value {
      Value dim =
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
      return rewriter.createOrFold<AtenSizeIntOp>(loc, inValue, dim);
    };

    outShape.insert(outShape.end(), {outC, outH, outW});
    auto inC = getDimSize(inRank - 3);
    auto inH = getDimSize(inRank - 2);
    auto inW = getDimSize(inRank - 1);

    SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
    for (uint64_t d : permutationTail) {
      permutation.push_back(rewriter.create<ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(nLeadingDims + d)));
    auto factor = op.getUpscaleFactor();

    Value factorSquared =
        rewriter.createOrFold<AtenMulIntOp>(loc, factor, factor);

    Value outC =
        rewriter.createOrFold<AtenFloordivIntOp>(loc, inC, factorSquared);

    Value outH = rewriter.createOrFold<AtenMulIntOp>(loc, inH, factor);
    Value outW = rewriter.createOrFold<AtenMulIntOp>(loc, inW, factor);

    SmallVector<Value> dimensionConstants;
    dimensionConstants.reserve(inRank + 2);
    for (unsigned i = 0; i < inRank + 2; ++i) {
      dimensionConstants.push_back(
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i)));
    }

    auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext()));
    SmallVector<Value> leadingDims;
    leadingDims.reserve(nLeadingDims);
    for (unsigned i = 0; i < nLeadingDims; ++i) {
      Value leadingDimSize = rewriter.createOrFold<AtenSizeIntOp>(
          loc, inValue, dimensionConstants[i]);
      leadingDims.push_back(leadingDimSize);
    }

    Value shapeA =
        rewriter.create<PrimListConstructOp>(loc, listType, prePermuteShape);
    SmallVector<Value> partiallyExpandedShape = leadingDims;
    partiallyExpandedShape.append({outC, factorSquared, inH, inW});

    Value A = rewriter.create<AtenReshapeOp>(
        loc, getTypeFromShape(prePermuteShape), inValue, shapeA);
    SmallVector<Value> prePermuteShape = leadingDims;
    prePermuteShape.append({outC, factor, factor, inH, inW});

    SmallVector<Value> postPermuteShape = leadingDims;
    postPermuteShape.append({outC, inH, factor, inW, factor});

    SmallVector<Value> partiallyCollapsedShape = leadingDims;
    partiallyCollapsedShape.append({outC, inH, factor, outW});

    SmallVector<Value> outShape = leadingDims;
    outShape.append({outC, outH, outW});

    SmallVector<Value> permutation{dimensionConstants.begin(),
                                   dimensionConstants.begin() + nLeadingDims};
    SmallVector<uint64_t> permutationTail{0, 3, 1, 4, 2};
    for (uint64_t d : permutationTail) {
      permutation.push_back(dimensionConstants[nLeadingDims + d]);
    }

    Value permuteDimsOrder = rewriter.create<PrimListConstructOp>(
        loc, Torch::ListType::get(Torch::IntType::get(op->getContext())),
        permutation);

    Value B = rewriter.create<AtenPermuteOp>(
        loc, getTypeFromShape(postPermuteShape), A, permuteDimsOrder);
    // Split input channel inC -> (outC, factorSquared)
    auto partiallyExpanded =
        rewriter
            .create<PrimsSplitDimOp>(
                loc, getTypeFromShape(partiallyExpandedShape), inValue,
                dimensionConstants[nLeadingDims], outC)
            .getResult();

    Value outShapeList =
        rewriter.create<PrimListConstructOp>(loc, listType, outShape);
    // Split new dimension factorSquared -> (factor, factor)
    auto fullyExpanded = rewriter.create<PrimsSplitDimOp>(
        loc, getTypeFromShape(prePermuteShape), partiallyExpanded,
        dimensionConstants[nLeadingDims + 1], factor);

    // Perform the permutation
    auto permuted =
        rewriter.create<AtenPermuteOp>(loc, getTypeFromShape(postPermuteShape),
                                       fullyExpanded, permuteDimsOrder);

    // Collapse final 2 dimension
    auto partiallyCollapsed = rewriter.create<PrimsCollapseOp>(
        loc, getTypeFromShape(partiallyCollapsedShape), permuted,
        dimensionConstants[nLeadingDims + 3],
        dimensionConstants[nLeadingDims + 4]);

    // Collapse back to original rank
    rewriter.replaceOpWithNewOp<PrimsCollapseOp>(
        op, op.getType(), partiallyCollapsed,
        dimensionConstants[nLeadingDims + 1],
        dimensionConstants[nLeadingDims + 2]);

    rewriter.replaceOpWithNewOp<AtenReshapeOp>(op, op.getType(), B,
                                               outShapeList);
    return success();
  }
};
} // namespace
// ReLU6(x) = min(max(0, x), 6) = min(Relu(x), 6)
static Value
getRelu6Results(PatternRewriter &rewriter, Location loc, Value input) {
static Value getRelu6Results(PatternRewriter &rewriter, Location loc,
                             Value input) {
  BaseTensorType inputType = input.getType().cast<BaseTensorType>();

  Value relu = rewriter.create<AtenReluOp>(loc, inputType, input);
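A quick check of the ReLU6 identity stated in the comment above, in plain PyTorch (a sketch for reference, not part of the diff; the sample values are made up):

import torch

x = torch.linspace(-3.0, 9.0, steps=25)
relu6_direct = torch.clamp(x, min=0.0, max=6.0)                    # min(max(0, x), 6)
relu6_via_relu = torch.minimum(torch.relu(x), torch.tensor(6.0))   # min(Relu(x), 6)
assert torch.equal(relu6_direct, relu6_via_relu)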
@@ -1815,7 +1815,7 @@ public:
    auto inputTensorType = self.getType().cast<Torch::ValueTensorType>();
    if (!inputTensorType || !inputTensorType.hasSizes()) {
      return rewriter.notifyMatchFailure(op,
                                         "Expected input type having sizes");
    }
    ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
@@ -1851,7 +1851,7 @@ public:
      Value dimSize =
          rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dimValue);
      if (i == dimInt) {
        int64_t inferredSizeInt = inputShape[i];
        int64_t inferredDim;
        for (unsigned j = 0; j < sizesInts.size(); ++j) {
          if (sizesInts[j] == -1) {
@@ -1865,11 +1865,9 @@ public:
          }
        }
        if (inferred) {
          Value inferredSize =
              rewriter.create<ConstantIntOp>(
          Value inferredSize = rewriter.create<ConstantIntOp>(
              loc, rewriter.getI64IntegerAttr(inferredSizeInt));
          newSizes.insert(
              newSizes.begin() + inferredDim + i, inferredSize);
          newSizes.insert(newSizes.begin() + inferredDim + i, inferredSize);
        }
      } else {
        newSizes.push_back(dimSize);
@@ -4095,7 +4093,7 @@ class DecomposeAtenClampMaxOp : public OpRewritePattern<AtenClampMaxOp> {
} // namespace

namespace {
class DecomposeAtenCosineSimilarityOp
    : public OpRewritePattern<AtenCosineSimilarityOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenCosineSimilarityOp op,
@@ -4122,7 +4120,7 @@ class DecomposeAtenCosineSimilarityOp
                                             indexBroadcastShapeTorchList);

    // Compute the mul of A and B
    Value dotProduct =
        rewriter.create<AtenMulTensorOp>(loc, broadcastType, x1, x2);
    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(loc);
@@ -4133,17 +4131,17 @@ class DecomposeAtenCosineSimilarityOp
        loc, op.getType(), /*self=*/dotProduct, /*dim=*/dimList,
        /*keepdim=*/cstFalse,
        /*dtype=*/cstNone);

    // Compute the norm of A and B
    Value ord = rewriter.create<Torch::ConstantFloatOp>(loc,
        rewriter.getF64FloatAttr(2.0));
    Value ord = rewriter.create<Torch::ConstantFloatOp>(
        loc, rewriter.getF64FloatAttr(2.0));
    Value normA = rewriter.create<AtenLinalgVectorNormOp>(
        loc, op.getType(), x1, ord, dimList, /*keepdim=*/cstFalse,
        /*dtype=*/cstNone);
    Value normB = rewriter.create<AtenLinalgVectorNormOp>(
        loc, op.getType(), x2, ord, dimList, /*keepdim=*/cstFalse,
        /*dtype=*/cstNone);

    // Compute the product of the norms
    Value normProduct =
        rewriter.create<AtenMulTensorOp>(loc, op.getType(), normA, normB);
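For reference, a plain PyTorch sketch, not part of the diff, of what the decomposition above computes: the elementwise product summed over the reduction dim, divided by the product of the two 2-norms. The example inputs and dim are made up, and the eps clamp that aten.cosine_similarity also applies to the denominator is omitted:

import torch

x1 = torch.randn(4, 8)
x2 = torch.randn(4, 8)
dim = 1

# mul + sum reduction, as in dotProduct above
dot = (x1 * x2).sum(dim)
# 2-norms along `dim`, as in the AtenLinalgVectorNormOp calls above
norm_a = torch.linalg.vector_norm(x1, ord=2, dim=dim)
norm_b = torch.linalg.vector_norm(x2, ord=2, dim=dim)
result = dot / (norm_a * norm_b)

assert torch.allclose(result, torch.cosine_similarity(x1, x2, dim=dim), atol=1e-6)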
@@ -948,8 +948,6 @@ STABLEHLO_CRASHING_SET = {
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "IscloseStaticModule_basic",
    "IscloseStaticModuleTrue_basic",
    "TileBigDimsSizeModule_basic",
@@ -1371,6 +1369,9 @@ LTC_XFAIL_SET = {
    "SplitDimDynamicModule_basic",
    "PixelShuffleModuleStaticRank3Int64_basic",
    "PixelShuffleModuleStaticRank4Float32_basic",
    "PixelShuffleModuleFullDynamic_basic",
    "PixelShuffleModuleSpatiallyDynamic_basic",
    "PixelShuffleModuleSpatiallyStatic_basic",
    "_Convolution2DAllFalseModule_basic",
    "_Convolution2DBenchmarkModule_basic",
    "_Convolution2DCudnnModule_basic",
@@ -686,8 +686,57 @@ class PixelShuffleModuleStaticRank3Int64(torch.nn.Module):
def PixelShuffleModuleStaticRank3Int64_basic(module, tu: TestUtils):
    module.forward(tu.randint(12, 2, 3, low = 0, high = 100))

# ==============================================================================


class PixelShuffleModuleFullDynamic(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1,-1,-1,-1], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 2)

@register_test_case(module_factory=lambda: PixelShuffleModuleFullDynamic())
def PixelShuffleModuleFullDynamic_basic(module, tu: TestUtils):
    module.forward(tu.randint(1,8,3,3, low = 0, high = 100))

# ==============================================================================

class PixelShuffleModuleSpatiallyDynamic(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([2,1,8,-1,-1], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 2)

@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyDynamic())
def PixelShuffleModuleSpatiallyDynamic_basic(module, tu: TestUtils):
    module.forward(tu.randint(2,1,8,2,3, low = 0, high = 100))


# ==============================================================================

class PixelShuffleModuleSpatiallyStatic(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([None, ([-1,-1,-1,3,1], torch.int64, True)])
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle(x, 2)

@register_test_case(module_factory=lambda: PixelShuffleModuleSpatiallyStatic())
def PixelShuffleModuleSpatiallyStatic_basic(module, tu: TestUtils):
    module.forward(tu.randint(1,2,12,3,1, low = 0, high = 100))


# ==============================================================================


class TensorsConcatModule(torch.nn.Module):