[Torch Dialect] Add Support for aten.unflatten.int (#2475)

As the title says, this adds support for aten.unflatten.int. `dim` may be
negative, and at most one element of `sizes` may be -1 (it is then inferred).
pull/2539/head snapshot-20231031.1008
JianzheXiao 2023-10-31 00:36:16 -07:00 committed by GitHub
parent b88f9ec8f2
commit e8706957c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 186 additions and 9 deletions

View File

@ -17,7 +17,6 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
# Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR failed
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
}
@ -705,6 +704,10 @@ STABLEHLO_PASS_SET = {
"ElementwiseToDtypeIdentityModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"UnflattenStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"RsubFloatModule_basic",
"RsubFloatModule_noalpha_basic",
"RsubIntModule_basic",
@ -994,6 +997,9 @@ TOSA_PASS_SET = {
"AtenToDeviceModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SqueezeDimModule_unitDim",

View File

@ -7240,13 +7240,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unflatten.int\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.list<int>) -> !torch.list<int> {\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.slice.t %arg0, %none, %arg1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %1 = torch.aten.add.t %0, %arg2 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %2 = torch.aten.add.int %arg1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %3 = torch.aten.slice.t %arg0, %2, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %4 = torch.aten.add.t %1, %3 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" %0 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.prim.If %0 -> (!torch.int) {\n"
" %10 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %11 = torch.aten.add.int %arg1, %10 : !torch.int, !torch.int -> !torch.int\n"
" torch.prim.If.yield %11 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %arg1 : !torch.int\n"
" }\n"
" %2 = torch.aten.__getitem__.t %arg0, %1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.prim.ListConstruct %2 : (!torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch.jit._shape_functions.view(%3, %arg2) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" %5 = torch.aten.slice.t %arg0, %none, %1, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>\n"
" %6 = torch.aten.add.t %5, %4 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" %7 = torch.aten.add.int %1, %int1 : !torch.int, !torch.int -> !torch.int\n"
" %8 = torch.aten.slice.t %arg0, %7, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>\n"
" %9 = torch.aten.add.t %6, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>\n"
" return %9 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.linear\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.linear(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>) -> !torch.list<int>\n"

View File

@ -1494,6 +1494,100 @@ public:
};
} // namespace
// Decompose aten.unflatten.int into aten.view op.
//
// The `dim`-th dimension of `self` is replaced by the entries of `sizes`; all
// other dimensions are forwarded unchanged through aten.size.int. `dim` may be
// negative, and at most one entry of `sizes` may be -1.
namespace {
class DecomposeAtenUnflattenIntOp
    : public OpRewritePattern<AtenUnflattenIntOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenUnflattenIntOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value self = op.getSelf();
    MLIRContext *context = op.getContext();
    BaseTensorType outputTensorType = op.getType().cast<BaseTensorType>();
    if (!outputTensorType.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "unimplemented: output must have known sizes");

    std::optional<unsigned> maybeRank = getTensorRank(self);
    if (!maybeRank)
      return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
    unsigned inputRank = *maybeRank;

    // dyn_cast rather than cast: a failed cast asserts, which would make the
    // null check below unreachable.
    auto inputTensorType = self.getType().dyn_cast<Torch::ValueTensorType>();
    if (!inputTensorType || !inputTensorType.hasSizes()) {
      return rewriter.notifyMatchFailure(op,
                                         "Expected input type having sizes");
    }
    ArrayRef<int64_t> inputShape = inputTensorType.getSizes();

    SmallVector<int64_t> sizesInts;
    if (!matchPattern(op.getSizes(), m_TorchListOfConstantInts(sizesInts)))
      return rewriter.notifyMatchFailure(
          op, "sizes must be a list of constant ints");

    const int64_t numInferred = llvm::count(sizesInts, -1);
    if (numInferred > 1)
      return rewriter.notifyMatchFailure(
          op, "only one of sizes' elements can be -1");
    const bool inferred = numInferred == 1;

    int64_t dimInt;
    if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)))
      return rewriter.notifyMatchFailure(
          op, "unimplemented: requires dim to be constants");

    dimInt = toPositiveDim(dimInt, inputRank);
    if (!isValidDim(dimInt, inputRank))
      return rewriter.notifyMatchFailure(op, "dim is not a valid dim");

    SmallVector<Value> sizesTorchInt;
    if (!getListConstructElements(op.getSizes(), sizesTorchInt))
      return rewriter.notifyMatchFailure(
          op, "Unimplemented: sizes not list of Scalar");

    // Resolve the value that stands in for the -1 entry *before* creating any
    // IR, so that every failure path leaves the IR untouched.
    int64_t inferredSizeInt = -1;
    if (inferred) {
      int64_t knownProduct = 1;
      for (int64_t size : sizesInts) {
        if (size != -1)
          knownProduct *= size;
      }
      const int64_t unflattenedExtent = inputShape[dimInt];
      if (unflattenedExtent >= 0) {
        // Static extent: the missing size is extent / product(known sizes).
        // Guard the division; a zero product makes the -1 entry ambiguous.
        if (knownProduct <= 0)
          return rewriter.notifyMatchFailure(
              op, "cannot infer -1 when the other sizes' product is not "
                  "positive");
        if (unflattenedExtent % knownProduct != 0)
          return rewriter.notifyMatchFailure(
              op, "unflattened dim is not divisible by the product of sizes");
        inferredSizeInt = unflattenedExtent / knownProduct;
      }
      // Otherwise the extent is not a static size (presumably the
      // dynamic-dimension sentinel -- TODO confirm). Keep -1 in the list and
      // let aten.view infer it, instead of dividing a sentinel value.
    }

    // Create new sizes based on the unflattened dimension.
    SmallVector<Value> newSizes;
    for (unsigned i = 0; i < inputRank; ++i) {
      if (i != static_cast<unsigned>(dimInt)) {
        // Untouched dimension: forward its (possibly dynamic) size.
        Value dimValue =
            rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
        newSizes.push_back(
            rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dimValue));
        continue;
      }
      // Unflattened dimension: splice in `sizes`, substituting the -1 entry.
      for (int64_t size : sizesInts) {
        const int64_t resolved = (size == -1) ? inferredSizeInt : size;
        newSizes.push_back(rewriter.create<ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(resolved)));
      }
    }

    // Create the AtenViewOp to replace the original op.
    Value newSizeList = rewriter.create<PrimListConstructOp>(
        loc, ListType::get(IntType::get(context)), newSizes);
    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), self,
                                            newSizeList);
    return success();
  }
};
} // namespace
// Decompose aten.expand into aten.broadcast_to op.
namespace {
class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
@ -5237,6 +5331,7 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenRepeatOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenExpandOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenFlattenUsingIntsOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenUnflattenIntOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarOtherOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenWhereScalarSelfOp>(patterns);

View File

@ -376,6 +376,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenRepeatOp>();
target.addIllegalOp<AtenExpandOp>();
target.addIllegalOp<AtenFlattenUsingIntsOp>();
target.addIllegalOp<AtenUnflattenIntOp>();
target.addIllegalOp<AtenWhereScalarOp>();
target.addIllegalOp<AtenWhereScalarOtherOp>();
target.addIllegalOp<AtenWhereScalarSelfOp>();

View File

@ -629,8 +629,18 @@ def atenadaptive_avg_pool2d〡shape(self: List[int], output_size: List[int])
def atenflattenusing_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]:
return upstream_shape_functions.flatten(self, start_dim, end_dim)
@check_shape_function([
    Invocation(TensorOfShape(3, 6, 8), 1, [3, 2]),
    Invocation(TensorOfShape(3, 6, 8), 1, [3, -1]), # contain one -1 in sizes
    Invocation(TensorOfShape(3, 6, 8), -1, [2, -1, 2]), # dim = -1
])
def atenunflattenint〡shape(self: List[int], dim: int, sizes: List[int]) -> List[int]:
    # Result shape: `self` with dimension `dim` replaced by `sizes`.
    # A negative `dim` counts from the end of `self`.
    if dim < 0:
        dim += len(self)
    # Run the upstream `view` shape function on the single unflattened
    # dimension so that a -1 entry in `sizes` is resolved instead of being
    # copied verbatim into the result.
    unflatten_shape: List[int] = [self[dim]]
    unflatten_shape_output = upstream_shape_functions.view(unflatten_shape, sizes)
    return self[:dim] + unflatten_shape_output + self[dim + 1:]
def atenlinear〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None) -> List[int]:
return upstream_shape_functions.linear(input, weight, bias)
@ -1668,7 +1678,7 @@ def atenflattenusing_ints〡dtype(self_rank_dtype: Tuple[int, int], start_
self_rank, self_dtype = self_rank_dtype
return self_dtype
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[1]))
@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, sizes=[-1]))
def atenunflattenint〡dtype(self_rank_dtype: Tuple[int, int], dim: int, sizes: List[int]) -> int:
    # The result dtype is the dtype of `self`; `dim` and `sizes` are not read.
    self_rank, self_dtype = self_rank_dtype
    return self_dtype

View File

@ -813,3 +813,56 @@ class ReshapeAliasCollapseModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ReshapeAliasCollapseModule())
def ReshapeAliasCollapseModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4))
# ==============================================================================
# Unflattens dimension 1 (size 24) of a static [3, 24, 5] tensor into
# [2, 4, 3], with every size given explicitly.
class UnflattenIntStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([3, 24, 5], torch.float32, True),
    ])
    def forward(self, inputs):
        return torch.ops.aten.unflatten(inputs, 1, [2, 4, 3])


@register_test_case(module_factory=lambda: UnflattenIntStaticModule())
def UnflattenIntStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(3, 24, 5))
# Exercises a negative `dim`: dim=-2 on a static [5, 12, 3] tensor selects the
# size-12 dimension, which is unflattened into [2, 2, 3, 1, 1].
class UnflattenIntNegativeOneDimStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 12, 3], torch.float32, True),
    ])
    def forward(self, inputs):
        return torch.ops.aten.unflatten(inputs, -2, [2, 2, 3, 1, 1])


@register_test_case(module_factory=lambda: UnflattenIntNegativeOneDimStaticModule())
def UnflattenIntNegativeOneDimStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 12, 3))
# Exercises a -1 entry in `sizes`: dim=-2 on a static [5, 12, 3] tensor is
# unflattened into [2, -1, 3, 1, 1], leaving one size to be inferred.
class UnflattenIntNegativeOneSizeStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, 12, 3], torch.float32, True),
    ])
    def forward(self, inputs):
        return torch.ops.aten.unflatten(inputs, -2, [2, -1, 3, 1, 1])


@register_test_case(module_factory=lambda: UnflattenIntNegativeOneSizeStaticModule())
def UnflattenIntNegativeOneSizeStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 12, 3))