[Torch Dialect] Support recompose aten.split.Tensor + prim.ListUnpack (#2192)
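
In PyTorch terms, this recomposes the common idiom where the list produced by
aten.split.Tensor is immediately unpacked, as in the sketch below (shapes are
illustrative, taken from the new e2e test):

    x = torch.rand(5, 3, 4)
    x1, x2, x3 = torch.ops.aten.split(x, 2, 0)

Rather than materializing a torch list in the IR, each unpacked result is
rewritten into a single aten.slice.Tensor on the split input.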

Yuanqiang Liu 2023-06-07 01:38:04 +08:00 committed by GitHub
parent e29c5e8003
commit faec8698ea
4 changed files with 74 additions and 3 deletions


@@ -775,6 +775,7 @@ STABLEHLO_PASS_SET = {
     "PrimsViewOfZeroRankModule_basic",
     "AtenComplex64Module_basic",
     "SplitTensorGetItem_Module_basic",
+    "SplitTensorListUnpackModule_basic",
     "UnbindIntListUnpack_Module_basic",
     "UnbindIntGetItem_Module_basic",
     "ChunkListUnpack_Module_basic",
@@ -1074,6 +1075,7 @@ TOSA_PASS_SET = {
     "TensorsConcatNegativeDimStaticModule_basic",
     "AtenComplex64Module_basic",
     "SplitTensorGetItem_Module_basic",
+    "SplitTensorListUnpackModule_basic",
     "ChunkListUnpack_Module_basic",
     "ChunkListUnpackUneven_Module_basic",
 }
@@ -1261,6 +1263,7 @@ LTC_XFAIL_SET = {
    "AtenComplexRealModule_basic",
    "AtenComplexViewModule_basic",
    "SplitTensorGetItem_Module_basic",
+   "SplitTensorListUnpackModule_basic",
    "UnbindIntListUnpack_Module_basic",
    "UnbindIntGetItem_Module_basic",
    "ChunkListUnpack_Module_basic",


@@ -3172,6 +3172,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
     return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
 
   // support for end < 0
   end = toPositiveDim(end, selfType.getShape()[dim]);
+  // support for end out of upper bound
+  end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end);
   // FIXME: add support for start < 0 and end < start
   if (end < start)
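
The added clamp mirrors Python/PyTorch slicing, where an end index past the
upper bound is truncated rather than rejected. A quick illustration:

    import torch

    x = torch.rand(5, 3, 4)
    # end index 6 exceeds dim 0 (size 5) and is clamped, as in x[4:5]
    assert x[4:6].shape == (1, 3, 4)
    assert torch.equal(x[4:6], x[4:5])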


@@ -246,6 +246,52 @@ public:
   }
 };
 
+class RecomposeSplitTensorListUnpack
+    : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    // Recompose AtenSplitTensorOp + PrimListUnpackOp into one
+    // AtenSliceTensorOp per unpacked result.
+    auto splitTensorOp =
+        dyn_cast<AtenSplitTensorOp>(op.getOperand().getDefiningOp());
+    if (!splitTensorOp)
+      return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
+    if (isListPotentiallyMutated(splitTensorOp.getResult()))
+      return rewriter.notifyMatchFailure(
+          op, "SplitTensorOp result is potentially mutated");
+
+    int64_t splitSize;
+    if (!matchPattern(splitTensorOp.getSplitSize(),
+                      m_TorchConstantInt(&splitSize)))
+      return rewriter.notifyMatchFailure(
+          op,
+          "Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
+
+    Location loc = op.getLoc();
+    Value step =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    SmallVector<Value> slices;
+    for (size_t i = 0; i < op.getNumResults(); i++) {
+      auto resultTy = op.getResult(i).getType();
+      auto start = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(i * splitSize));
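+      // Note: for the last result, (i + 1) * splitSize may exceed the
+      // dimension size. aten.slice.Tensor follows PyTorch slicing
+      // semantics, so an out-of-bounds end index is clamped and the final
+      // slice simply comes out shorter, matching aten.split.Tensor.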
+      auto end = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr((i + 1) * splitSize));
+      Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
+          loc, resultTy, splitTensorOp.getSelf(), splitTensorOp.getDim(), start,
+          end, step);
+      slices.push_back(sliceTensorOp);
+    }
+    rewriter.replaceOp(op, slices);
+    // Erase the split op if the unpack was its only user.
+    if (splitTensorOp.getResult().use_empty())
+      rewriter.eraseOp(splitTensorOp);
+    return success();
+  }
+};
+
 class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
@@ -312,6 +358,7 @@ public:
     patterns.add<RecomposeSliceCopy_>(context);
     patterns.add<RecomposeSelectFill_>(context);
     patterns.add<RecomposeSplitTensorGetItemOp>(context);
+    patterns.add<RecomposeSplitTensorListUnpack>(context);
     patterns.add<RecomposeUnbindListUnpack>(context);
     patterns.add<RecomposeUnbindGetItem>(context);
     patterns.add<RecomposeChunkListUnpack>(context);
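
The rewrite's correctness contract can be checked directly in PyTorch; note
that the last chunk of an uneven split is shorter, which is exactly the case
the end-index clamp above handles. A small sanity sketch (shapes taken from
the new e2e test):

    import torch

    x = torch.rand(5, 3, 4)
    splits = torch.ops.aten.split(x, 2, 0)
    # Each split result equals a stride-1 slice; the last end index (6)
    # overshoots dim 0 (size 5) and is clamped, just like x[4:6] == x[4:5].
    slices = [x[i * 2:(i + 1) * 2] for i in range(len(splits))]
    assert all(torch.equal(a, b) for a, b in zip(splits, slices))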


@@ -665,15 +665,34 @@ class SplitTensorGetItem_Module(torch.nn.Module):
     @export
     @annotate_args([
         None,
-        ([2, 3, 4], torch.float32, True),
+        ([3, 3, 4], torch.float32, True),
     ])
     def forward(self, x):
-        splits = torch.ops.aten.split(x, 1, 0)
+        splits = torch.ops.aten.split(x, 2, 0)
         return torch.ops.aten.sub(splits[0], splits[1])
 
 
 @register_test_case(module_factory=lambda: SplitTensorGetItem_Module())
 def SplitTensorGetItem_Module_basic(module, tu: TestUtils):
-    module.forward(tu.rand(2, 3, 4))
+    module.forward(tu.rand(3, 3, 4))
 
 # ==============================================================================
 
+class SplitTensorListUnpackModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([5, 3, 4], torch.float32, True),
+    ])
+    def forward(self, x):
+        x1, x2, x3 = torch.ops.aten.split(x, 2, 0)
+        return x1 + x2 + x3
+
+
+@register_test_case(module_factory=lambda: SplitTensorListUnpackModule())
+def SplitTensorListUnpackModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 3, 4))
+
+# ==============================================================================