mirror of https://github.com/llvm/torch-mlir
[Torch Dialect] support aten.split_with_sizes (#2431)
* [Torch Dialect] support aten.split_with_sizes * updatepull/2436/head
parent
cd1c7df8be
commit
e9ab8ceb1c
|
@ -5,6 +5,7 @@ blacklist:
|
|||
|
||||
# Ops with list of tensors output
|
||||
- split.Tensor
|
||||
- split_with_sizes
|
||||
- unbind.int
|
||||
- chunk
|
||||
|
||||
|
|
|
@ -872,6 +872,7 @@ STABLEHLO_PASS_SET = {
|
|||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
|
@ -1216,6 +1217,7 @@ TOSA_PASS_SET = {
|
|||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"ChunkListUnpack_Module_basic",
|
||||
"ChunkListUnpackUneven_Module_basic",
|
||||
"TupleModule_basic",
|
||||
|
|
|
@ -11017,6 +11017,30 @@ def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$split_sizes,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSplitWithSizesOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenSplitWithSizesOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -1444,7 +1444,7 @@ OpFoldResult AtenAnyBoolOp::fold(FoldAdaptor adaptor) {
|
|||
return nullptr;
|
||||
// If any operand is a constant true, return true.
|
||||
for (auto operand : inputConstruct.getOperands()) {
|
||||
bool b;
|
||||
bool b = false;
|
||||
if (matchPattern(operand, m_TorchConstantBool(&b)) && b) {
|
||||
return getI1IntegerAttr(getContext(), true);
|
||||
}
|
||||
|
|
|
@ -363,6 +363,94 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class RecomposeSplitWithSizesListUnpack
|
||||
: public OpRewritePattern<PrimListUnpackOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimListUnpackOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
// recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps
|
||||
auto splitOp =
|
||||
dyn_cast<AtenSplitWithSizesOp>(op.getOperand().getDefiningOp());
|
||||
if (!splitOp) {
|
||||
return rewriter.notifyMatchFailure(op,
|
||||
"Input is not AtenSplitWithSizesOp");
|
||||
}
|
||||
if (isListPotentiallyMutated(splitOp.getResult())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "splitWithSizesOp result is potentially mutated");
|
||||
}
|
||||
if (isListPotentiallyMutated(splitOp.getSplitSizes())) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "splitWithSizesOp's split_sizes is potentially mutated");
|
||||
}
|
||||
auto splitSizesConstruct =
|
||||
splitOp.getSplitSizes().getDefiningOp<Torch::PrimListConstructOp>();
|
||||
if (!splitSizesConstruct) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "split_sizes is not from PrimListConstructOp");
|
||||
}
|
||||
|
||||
int64_t sumSplitSize = 0;
|
||||
SmallVector<int64_t> splitSizes;
|
||||
for (auto operand : splitSizesConstruct.getOperands()) {
|
||||
int64_t value = -1;
|
||||
// TODO: support when split_sizes are not constant int
|
||||
if (!matchPattern(operand, m_TorchConstantInt(&value))) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "one of split_sizes is not constant int");
|
||||
}
|
||||
if (value < 0) {
|
||||
return rewriter.notifyMatchFailure(op, "all of split_sizes must > 0");
|
||||
}
|
||||
sumSplitSize += value;
|
||||
splitSizes.push_back(value);
|
||||
}
|
||||
if (splitSizes.size() != op.getNumResults()) {
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "split_sizes must be same as splitOp result size");
|
||||
}
|
||||
|
||||
Location loc = op.getLoc();
|
||||
Value input = splitOp.getSelf();
|
||||
Value dim = splitOp.getDim();
|
||||
|
||||
// add runtime.assert to check rank constraint
|
||||
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
|
||||
Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(sumSplitSize));
|
||||
Value eqOrNot =
|
||||
rewriter.create<AtenEqIntOp>(loc, totalSize, cstSumSplitSize);
|
||||
rewriter.create<RuntimeAssertOp>(
|
||||
loc, eqOrNot,
|
||||
rewriter.getStringAttr("split dim must be sum of split_sizes"));
|
||||
|
||||
// calculate slice op's lower bound and up bound
|
||||
SmallVector<int64_t> boundaryOfSliceOp(splitSizes.size() + 1, 0);
|
||||
for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) {
|
||||
boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1];
|
||||
}
|
||||
SmallVector<Value> slices;
|
||||
Value cstOne =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
for (size_t i = 0; i < op.getNumResults(); i++) {
|
||||
auto resultTy = op.getResult(i).getType();
|
||||
auto start = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[i]));
|
||||
auto end = rewriter.create<Torch::ConstantIntOp>(
|
||||
loc, rewriter.getI64IntegerAttr((boundaryOfSliceOp[i + 1])));
|
||||
Value sliceTensorOp = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, resultTy, input, dim, start, end, /*step=*/cstOne);
|
||||
slices.push_back(sliceTensorOp);
|
||||
}
|
||||
rewriter.replaceOp(op, slices);
|
||||
// erase splitOp if no user left
|
||||
if (splitOp.getResult().use_empty())
|
||||
rewriter.eraseOp(splitOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
|
@ -436,6 +524,7 @@ public:
|
|||
patterns.add<RecomposeSelectFill_>(context);
|
||||
patterns.add<RecomposeSplitTensorGetItemOp>(context);
|
||||
patterns.add<RecomposeSplitTensorListUnpack>(context);
|
||||
patterns.add<RecomposeSplitWithSizesListUnpack>(context);
|
||||
patterns.add<RecomposeUnbindListUnpack>(context);
|
||||
patterns.add<RecomposeUnbindGetItem>(context);
|
||||
patterns.add<RecomposeChunkListUnpack>(context);
|
||||
|
|
|
@ -651,6 +651,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True)
|
||||
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)")
|
||||
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
|
||||
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
|
||||
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
|
||||
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
|
||||
|
||||
|
|
|
@ -800,6 +800,26 @@ def SplitTensorNegativeDimModule_basic(module, tu: TestUtils):
|
|||
|
||||
# ==============================================================================
|
||||
|
||||
class SplitWithSizesListUnpackModule(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([10, 12], torch.float32, True)
|
||||
])
|
||||
def forward(self, x):
|
||||
s0, s1, s2 = torch.ops.aten.split_with_sizes(x, [3, 4, 5], -1)
|
||||
return (s0, s1, s2)
|
||||
|
||||
@register_test_case(module_factory=lambda: SplitWithSizesListUnpackModule())
|
||||
def SplitWithSizesListUnpackModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(10, 12))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class ChunkListUnpack_Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
Loading…
Reference in New Issue