mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] Support recompose aten.split.Tensor + prim.ListUnpack (#2192)
parent
e29c5e8003
commit
faec8698ea
|
@ -775,6 +775,7 @@ STABLEHLO_PASS_SET = {
|
|||
"PrimsViewOfZeroRankModule_basic",
|
||||
"AtenComplex64Module_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
|
@ -1074,6 +1075,7 @@ TOSA_PASS_SET = {
|
|||
"TensorsConcatNegativeDimStaticModule_basic",
|
||||
"AtenComplex64Module_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
}
|
||||
|
@ -1261,6 +1263,7 @@ LTC_XFAIL_SET = {
|
|||
"AtenComplexRealModule_basic",
|
||||
"AtenComplexViewModule_basic",
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
|
|
|
@ -3172,6 +3172,8 @@ LogicalResult ConvertAtenOp<AtenSliceTensorOp>::matchAndRewrite(
|
|||
return rewriter.notifyMatchFailure(op, "end must be a Scalar constant");
|
||||
// support for end < 0
|
||||
end = toPositiveDim(end, selfType.getShape()[dim]);
|
||||
// support for end out of upper bound
|
||||
end = (end > selfType.getShape()[dim] ? selfType.getShape()[dim] : end);
|
||||
|
||||
// FIXME: add support for start < 0 and end < start
|
||||
if (end < start)
|
||||
|
|
|
@ -246,6 +246,52 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class RecomposeSplitTensorListUnpack
|
||||
: public OpRewritePattern<PrimListUnpackOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimListUnpackOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// recompose AtenSplitTensorOp + PrimListUnpackOp to AtenSliceTensorOps
|
||||
auto splitTensorOp =
|
||||
dyn_cast<AtenSplitTensorOp>(op.getOperand().getDefiningOp());
|
||||
if (!splitTensorOp)
|
||||
return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
|
||||
if (isListPotentiallyMutated(splitTensorOp.getResult()))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "SplitTensorOp result is potentially mutated");
|
||||
|
||||
int64_t splitSize;
|
||||
if (!matchPattern(splitTensorOp.getSplitSize(),
|
||||
m_TorchConstantInt(&splitSize)))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"Expected `SplitSize` of `AtenSplitTensorOp` to be a constant int");
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value step =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
|
||||
SmallVector<Value> slices;
|
||||
for (size_t i = 0; i < op.getNumResults(); i++) {
|
||||
auto resultTy = op.getResult(i).getType();
|
||||
auto start = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(i * splitSize));
|
||||
auto end = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr((i + 1) * splitSize));
|
||||
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, resultTy, splitTensorOp.getSelf(), splitTensorOp.getDim(), start,
|
||||
end, step);
|
||||
slices.push_back(sliceTensorOp);
|
||||
}
|
||||
rewriter.replaceOp(op, slices);
|
||||
// erase splitTensorOp if no user left
|
||||
if (splitTensorOp.getResult().use_empty())
|
||||
rewriter.eraseOp(splitTensorOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
@ -312,6 +358,7 @@ public:
|
|||
patterns.add<RecomposeSliceCopy_>(context);
|
||||
patterns.add<RecomposeSelectFill_>(context);
|
||||
patterns.add<RecomposeSplitTensorGetItemOp>(context);
|
||||
patterns.add<RecomposeSplitTensorListUnpack>(context);
|
||||
patterns.add<RecomposeUnbindListUnpack>(context);
|
||||
patterns.add<RecomposeUnbindGetItem>(context);
|
||||
patterns.add<RecomposeChunkListUnpack>(context);
|
||||
|
|
|
@ -665,15 +665,34 @@ class SplitTensorGetItem_Module(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([2, 3, 4], torch.float32, True),
|
||||
([3, 3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
splits = torch.ops.aten.split(x, 1, 0)
|
||||
splits = torch.ops.aten.split(x, 2, 0)
|
||||
return torch.ops.aten.sub(splits[0], splits[1])
|
||||
|
||||
@register_test_case(module_factory=lambda: SplitTensorGetItem_Module())
|
||||
def SplitTensorGetItem_Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 3, 4))
|
||||
module.forward(tu.rand(3, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SplitTensorListUnpackModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, 3, 4], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
x1, x2, x3 = torch.ops.aten.split(x, 2, 0)
|
||||
return x1 + x2 + x3
|
||||
|
||||
@register_test_case(module_factory=lambda: SplitTensorListUnpackModule())
|
||||
def SplitTensorListUnpackModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 3, 4))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
|
|
Loading…
Reference in New Issue