[Torch Dialect] Add split.tensor support + recompose rules (#2102)

* add split.tensor support + recompose rules

* add e2e test

* address comments

* address comments

* erase op in recomposeOp

---------

Co-authored-by: zhekun.zhang <zhekun.zhang@bytedance.com>
pull/2158/head snapshot-20230524.848
Zhekun Zhang 2023-05-23 12:43:33 -07:00 committed by GitHub
parent 080fad7c07
commit a426363b7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 93 additions and 0 deletions

View File

@ -8,6 +8,7 @@ blacklist:
- index_put_ # Error: TODO not sure if there are other valid types to handle here
# Ops with list of tensors output
- split.Tensor
- unbind.int
# Additional ops which autogen is supported for but don't compile yet

View File

@ -726,6 +726,7 @@ STABLEHLO_PASS_SET = {
"PrimsViewOfModule_basic",
"PrimsViewOfZeroRankModule_basic",
"AtenComplex64Module_basic",
"SplitTensorGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
}
@ -1012,6 +1013,7 @@ TOSA_PASS_SET = {
"TensorsConcatStaticModule_basic",
"TensorsConcatNegativeDimStaticModule_basic",
"AtenComplex64Module_basic",
"SplitTensorGetItem_Module_basic",
}
LTC_XFAIL_SET = {
@ -1191,6 +1193,7 @@ LTC_XFAIL_SET = {
"AtenComplexImagModule_basic",
"AtenComplexRealModule_basic",
"AtenComplexViewModule_basic",
"SplitTensorGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
"UnbindIntGetItem_Module_basic",
}

View File

@ -9566,6 +9566,30 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [
}];
}
def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::split.Tensor : (Tensor, int, int) -> (Tensor[])`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$split_size,
Torch_IntType:$dim
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenSplitTensorOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenSplitTensorOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
AllowsTypeRefinement,
ReadOnly

View File

@ -181,6 +181,48 @@ public:
return success();
}
};
class RecomposeSplitTensorGetItemOp
: public OpRewritePattern<Aten__Getitem__TOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
PatternRewriter &rewriter) const override {
// recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp
auto splitTensorOp =
dyn_cast<AtenSplitTensorOp>(op.getList().getDefiningOp());
if (!splitTensorOp)
return failure();
if (isListPotentiallyMutated(splitTensorOp.getResult()))
return failure();
int64_t index;
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
int64_t splitSize;
if (!matchPattern(splitTensorOp.getSplitSize(),
m_TorchConstantInt(&splitSize)))
return rewriter.notifyMatchFailure(
op,
"Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
Location loc = op.getLoc();
Value step =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value start = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(index * splitSize));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(index * splitSize + splitSize));
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
loc, op.getResult().getType(), splitTensorOp.getSelf(),
splitTensorOp.getDim(), start, end, step);
rewriter.replaceOp(op, sliceTensorOp);
if (splitTensorOp.getResult().use_empty())
rewriter.eraseOp(splitTensorOp);
return success();
}
};
} // namespace
namespace {
@ -194,6 +236,7 @@ public:
// pattern.add calls go here
patterns.add<RecomposeSliceCopy_>(context);
patterns.add<RecomposeSelectFill_>(context);
patterns.add<RecomposeSplitTensorGetItemOp>(context);
patterns.add<RecomposeUnbindListUnpack>(context);
patterns.add<RecomposeUnbindGetItem>(context);

View File

@ -590,6 +590,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::any.bool : (bool[]) -> (bool)")
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
# Str ops.

View File

@ -581,3 +581,24 @@ class UnbindIntGetItem_Module(torch.nn.Module):
@register_test_case(module_factory=lambda: UnbindIntGetItem_Module())
def UnbindIntGetItem_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))
# ==============================================================================
class SplitTensorGetItem_Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([2, 3, 4], torch.float32, True),
])
def forward(self, x):
splits = torch.ops.aten.split(x, 1, 0)
return torch.ops.aten.sub(splits[0], splits[1])
@register_test_case(module_factory=lambda: SplitTensorGetItem_Module())
def SplitTensorGetItem_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3, 4))