[Torch Dialect] support aten.split_with_sizes (#2431)

* [Torch Dialect] support aten.split_with_sizes

* update
pull/2436/head
Yuanqiang Liu 2023-09-04 09:59:26 +08:00 committed by GitHub
parent cd1c7df8be
commit e9ab8ceb1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 138 additions and 1 deletion

View File

@ -5,6 +5,7 @@ blacklist:
# Ops with list of tensors output
- split.Tensor
- split_with_sizes
- unbind.int
- chunk

View File

@ -872,6 +872,7 @@ STABLEHLO_PASS_SET = {
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitTensorLastSmallerModule_basic",
"SplitWithSizesListUnpackModule_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
"ChunkListUnpack_Module_basic",
@ -1216,6 +1217,7 @@ TOSA_PASS_SET = {
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitTensorLastSmallerModule_basic",
"SplitWithSizesListUnpackModule_basic",
"ChunkListUnpack_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"TupleModule_basic",

View File

@ -11017,6 +11017,30 @@ def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
}];
}
// ODS definition for `aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])`.
// Splits `self` along dimension `dim` into len(split_sizes) pieces, where
// piece i has extent split_sizes[i] along that dimension.
// NOTE(review): this entry lives in a generated .td file (see the matching
// emit() in torch_ods_gen.py) — regenerate rather than hand-editing.
def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [
    AllowsTypeRefinement,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$split_sizes,
    Torch_IntType:$dim
  );
  let results = (outs
    AnyTorchListOfTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    // Parse/print via the shared default-torch-op helpers: 3 operands, 1 result.
    ParseResult AtenSplitWithSizesOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenSplitWithSizesOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}
def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
AllowsTypeRefinement,
ReadOnly

View File

@ -1444,7 +1444,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
return nullptr;
// If any operand is a constant true, return true.
for (auto operand : inputConstruct.getOperands()) {
bool b;
bool b = false;
if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
return getI1IntegerAttr(getContext(), true);
}

View File

@ -363,6 +363,94 @@ public:
}
};
// Recomposes aten.split_with_sizes + prim.ListUnpack into one
// aten.slice.Tensor per unpacked result, so that downstream lowerings never
// see the list-of-tensors result. Only fires when every split size is a
// constant non-negative int and the list value is not mutated.
class RecomposeSplitWithSizesListUnpack
    : public OpRewritePattern<PrimListUnpackOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                PatternRewriter &rewriter) const override {
    // recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps.
    // Use the templated getDefiningOp<>() so a block-argument operand yields
    // null instead of feeding a null Operation* into dyn_cast (which asserts).
    auto splitOp = op.getOperand().getDefiningOp<AtenSplitWithSizesOp>();
    if (!splitOp) {
      return rewriter.notifyMatchFailure(op,
                                         "Input is not AtenSplitWithSizesOp");
    }
    if (isListPotentiallyMutated(splitOp.getResult())) {
      return rewriter.notifyMatchFailure(
          op, "splitWithSizesOp result is potentially mutated");
    }
    if (isListPotentiallyMutated(splitOp.getSplitSizes())) {
      return rewriter.notifyMatchFailure(
          op, "splitWithSizesOp's split_sizes is potentially mutated");
    }
    auto splitSizesConstruct =
        splitOp.getSplitSizes().getDefiningOp<Torch::PrimListConstructOp>();
    if (!splitSizesConstruct) {
      return rewriter.notifyMatchFailure(
          op, "split_sizes is not from PrimListConstructOp");
    }
    // Collect the constant split sizes and their total.
    int64_t sumSplitSize = 0;
    SmallVector<int64_t> splitSizes;
    for (auto operand : splitSizesConstruct.getOperands()) {
      int64_t value = -1;
      // TODO: support when split_sizes are not constant int
      if (!matchPattern(operand, m_TorchConstantInt(&value))) {
        return rewriter.notifyMatchFailure(
            op, "one of split_sizes is not constant int");
      }
      // Negative sizes are invalid; zero-size splits are permitted.
      if (value < 0) {
        return rewriter.notifyMatchFailure(
            op, "all of split_sizes must be non-negative");
      }
      sumSplitSize += value;
      splitSizes.push_back(value);
    }
    if (splitSizes.size() != op.getNumResults()) {
      return rewriter.notifyMatchFailure(
          op, "split_sizes must be same as splitOp result size");
    }
    Location loc = op.getLoc();
    Value input = splitOp.getSelf();
    Value dim = splitOp.getDim();
    // Emit a runtime assert that the extent of `dim` equals the sum of the
    // split sizes (cannot be checked statically for dynamic shapes).
    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
    Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(sumSplitSize));
    Value eqOrNot =
        rewriter.create<AtenEqIntOp>(loc, totalSize, cstSumSplitSize);
    rewriter.create<RuntimeAssertOp>(
        loc, eqOrNot,
        rewriter.getStringAttr("split dim must be sum of split_sizes"));
    // Prefix-sum the split sizes to get each slice's [start, end) bounds.
    SmallVector<int64_t> boundaryOfSliceOp(splitSizes.size() + 1, 0);
    for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) {
      boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1];
    }
    SmallVector<Value> slices;
    Value cstOne =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    for (size_t i = 0; i < op.getNumResults(); i++) {
      auto resultTy = op.getResult(i).getType();
      auto start = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i]));
      auto end = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1])));
      Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
          loc, resultTy, input, dim, start, end, /*step=*/cstOne);
      slices.push_back(sliceTensorOp);
    }
    rewriter.replaceOp(op, slices);
    // erase splitOp if no user left
    if (splitOp.getResult().use_empty())
      rewriter.eraseOp(splitOp);
    return success();
  }
};
class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
public:
using OpRewritePattern::OpRewritePattern;
@ -436,6 +524,7 @@ public:
patterns.add<RecomposeSelectFill_>(context);
patterns.add<RecomposeSplitTensorGetItemOp>(context);
patterns.add<RecomposeSplitTensorListUnpack>(context);
patterns.add<RecomposeSplitWithSizesListUnpack>(context);
patterns.add<RecomposeUnbindListUnpack>(context);
patterns.add<RecomposeUnbindGetItem>(context);
patterns.add<RecomposeChunkListUnpack>(context);

View File

@ -651,6 +651,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")

View File

@ -800,6 +800,26 @@ def SplitTensorNegativeDimModule_basic(module, tu: TestUtils):
# ==============================================================================
class SplitWithSizesListUnpackModule(torch.nn.Module):
    """E2e module exercising aten.split_with_sizes followed by list unpacking."""

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([10, 12], torch.float32, True)
    ])
    def forward(self, x):
        # Split the last dimension into pieces of widths 3, 4 and 5.
        pieces = torch.ops.aten.split_with_sizes(x, [3, 4, 5], -1)
        first, second, third = pieces
        return first, second, third
@register_test_case(module_factory=lambda: SplitWithSizesListUnpackModule())
def SplitWithSizesListUnpackModule_basic(module, tu: TestUtils):
    """Runs the module on a random 10x12 float input."""
    sample = tu.rand(10, 12)
    module.forward(sample)
# ==============================================================================
class ChunkListUnpack_Module(torch.nn.Module):
def __init__(self):
super().__init__()