[Torch] support recompose of aten.split.with_sizes and aten.tensor_sp… (#3401)

…lit.sections

* support recompose to aten.split.with_sizes and
aten.tensor_split.sections
* fix recompose of aten.chunk
pull/3405/head
Yuanqiang Liu 2024-05-31 09:56:47 +08:00 committed by GitHub
parent 074098d20c
commit 4e05e2cd1e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 339 additions and 43 deletions

View File

@ -13526,6 +13526,31 @@ def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
let hasCanonicalizer = 1;
}
def Torch_AtenTensorSplitSectionsOp : Torch_Op<"aten.tensor_split.sections", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])`";
let arguments = (ins
AnyTorchTensorType:$self,
Torch_IntType:$sections,
Torch_IntType:$dim
);
let results = (outs
AnyTorchListOfTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenTensorSplitSectionsOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenTensorSplitSectionsOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [

View File

@ -3050,6 +3050,19 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
});
}
//===----------------------------------------------------------------------===//
// AtenSplitSizesOp
//===----------------------------------------------------------------------===//
void AtenSplitSizesOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenSplitSizesOp op, PatternRewriter &rewriter) {
rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
return success();
});
}
//===----------------------------------------------------------------------===//
// AtenIsFloatingPointOp
//===----------------------------------------------------------------------===//

View File

@ -862,18 +862,6 @@ class DecomposePrimTolistOp : public OpRewritePattern<PrimTolistOp> {
};
} // namespace
namespace {
class DecomposeAtenSplitSizesOp : public OpRewritePattern<AtenSplitSizesOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSplitSizesOp op,
PatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
return success();
}
};
} // namespace
namespace {
class DecomposeAtenSplitWithSizesOp
: public OpRewritePattern<AtenSplitWithSizesOp> {
@ -8084,7 +8072,6 @@ public:
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitSizesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);

View File

@ -13,6 +13,7 @@
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include <numeric>
using namespace mlir;
using namespace mlir::torch;
@ -164,7 +165,7 @@ public:
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenUnbindOp + PrimListUnpackOp to select.int
auto unbindOp = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
auto unbindOp = op.getOperand().getDefiningOp<AtenUnbindIntOp>();
if (!unbindOp)
return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
if (isListPotentiallyMutated(unbindOp.getResult()))
@ -207,7 +208,7 @@ public:
LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
PatternRewriter &rewriter) const override {
// recompose AtenUnbindIntOp + __getitem__t to select.int
auto unbind = dyn_cast<AtenUnbindIntOp>(op.getList().getDefiningOp());
auto unbind = op.getList().getDefiningOp<AtenUnbindIntOp>();
if (!unbind)
return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
if (isListPotentiallyMutated(unbind.getResult()))
@ -243,15 +244,14 @@ public:
}
};
class RecomposeSplitTensorGetItemOp
class RecomposeSplitTensorGetItem
: public OpRewritePattern<Aten__Getitem__TOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
PatternRewriter &rewriter) const override {
// recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp
auto splitTensorOp =
dyn_cast<AtenSplitTensorOp>(op.getList().getDefiningOp());
auto splitTensorOp = op.getList().getDefiningOp<AtenSplitTensorOp>();
if (!splitTensorOp)
return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
if (isListPotentiallyMutated(splitTensorOp.getResult()))
@ -308,8 +308,7 @@ public:
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenSplitTensorOp + PrimListUnpackOp to AtenSliceTensorOps
auto splitTensorOp =
dyn_cast<AtenSplitTensorOp>(op.getOperand().getDefiningOp());
auto splitTensorOp = op.getOperand().getDefiningOp<AtenSplitTensorOp>();
if (!splitTensorOp)
return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
if (isListPotentiallyMutated(splitTensorOp.getResult()))
@ -362,6 +361,78 @@ public:
}
};
class RecomposeSplitWithSizesGetItem
: public OpRewritePattern<Aten__Getitem__TOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
PatternRewriter &rewriter) const override {
// recompose AtenSplitWithSizes + __getitem__t to AtenSliceTensorOp
auto splitWithSizesOp = op.getList().getDefiningOp<AtenSplitWithSizesOp>();
if (!splitWithSizesOp)
return rewriter.notifyMatchFailure(op,
"Input is not AtenSplitWithSizesOp");
if (isListPotentiallyMutated(splitWithSizesOp.getResult()))
return rewriter.notifyMatchFailure(
op, "AtenSplitWithSizesOp result is potentially mutated");
if (isListPotentiallyMutated(splitWithSizesOp.getSplitSizes())) {
return rewriter.notifyMatchFailure(
op, "splitWithSizesOp's split_sizes is potentially mutated");
}
SmallVector<int64_t> splitSizes;
if (!matchPattern(splitWithSizesOp.getSplitSizes(),
m_TorchListOfConstantInts(splitSizes))) {
return rewriter.notifyMatchFailure(
op, "split_sizes must be list of constant int");
}
int64_t index;
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
index = toPositiveDim(index, splitSizes.size());
if (!isValidDim(index, splitSizes.size()))
return rewriter.notifyMatchFailure(
op, "Expected `idx` in range of split_sizes");
Location loc = op.getLoc();
Value input = splitWithSizesOp.getSelf();
Value dim = splitWithSizesOp.getDim();
// add runtime.assert to check dimension constraint
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
int64_t sumSplitSize =
std::accumulate(splitSizes.begin(), splitSizes.end(), 0);
Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sumSplitSize));
Value eqOrNot =
rewriter.create<AtenEqIntOp>(loc, totalSize, cstSumSplitSize);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr("split dim must be sum of split_sizes"));
// replace with AtenSliceTensorOp
SmallVector<int64_t> boundaryOfSliceOp(splitSizes.size() + 1, 0);
for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) {
boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1];
}
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
auto start = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index]));
auto end = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index + 1]));
Value slice = rewriter.create<AtenSliceTensorOp>(
loc, op.getType(), input, dim, start, end, /*step=*/cstOne);
rewriter.replaceOp(op, slice);
// erase splitOp if no user left
if (splitWithSizesOp.getResult().use_empty())
rewriter.eraseOp(splitWithSizesOp);
return success();
}
};
class RecomposeSplitWithSizesListUnpack
: public OpRewritePattern<PrimListUnpackOp> {
public:
@ -369,8 +440,7 @@ public:
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps
auto splitOp =
dyn_cast<AtenSplitWithSizesOp>(op.getOperand().getDefiningOp());
auto splitOp = op.getOperand().getDefiningOp<AtenSplitWithSizesOp>();
if (!splitOp) {
return rewriter.notifyMatchFailure(op,
"Input is not AtenSplitWithSizesOp");
@ -390,20 +460,11 @@ public:
op, "split_sizes is not from PrimListConstructOp");
}
int64_t sumSplitSize = 0;
SmallVector<int64_t> splitSizes;
for (auto operand : splitSizesConstruct.getOperands()) {
int64_t value = -1;
// TODO: support when split_sizes are not constant int
if (!matchPattern(operand, m_TorchConstantInt(&value))) {
if (!matchPattern(splitOp.getSplitSizes(),
m_TorchListOfConstantInts(splitSizes))) {
return rewriter.notifyMatchFailure(
op, "one of split_sizes is not constant int");
}
if (value < 0) {
return rewriter.notifyMatchFailure(op, "all of split_sizes must > 0");
}
sumSplitSize += value;
splitSizes.push_back(value);
op, "split_sizes must be list of constant int");
}
if (splitSizes.size() != op.getNumResults()) {
return rewriter.notifyMatchFailure(
@ -416,6 +477,8 @@ public:
// add runtime.assert to check rank constraint
Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
int64_t sumSplitSize =
std::accumulate(splitSizes.begin(), splitSizes.end(), 0);
Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(sumSplitSize));
Value eqOrNot =
@ -450,13 +513,156 @@ public:
}
};
class RecomposeTensorSplitSectionsGetItem
: public OpRewritePattern<Aten__Getitem__TOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
PatternRewriter &rewriter) const override {
// recompose AtenTensorSplitSectionsOp + __getitem__t to AtenSliceTensorOp
auto splitOp = op.getList().getDefiningOp<AtenTensorSplitSectionsOp>();
if (!splitOp)
return rewriter.notifyMatchFailure(
op, "Input is not AtenTensorSplitSectionsOp");
if (isListPotentiallyMutated(splitOp.getResult()))
return rewriter.notifyMatchFailure(
op, "AtenTensorSplitSectionsOp result is potentially mutated");
int64_t sections;
if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(&sections)))
return rewriter.notifyMatchFailure(
op, "Expected `sections` of AtenTensorSplitSectionsOp to be a "
"constant int");
int64_t index;
if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
index = toPositiveDim(index, sections);
if (!isValidDim(index, sections))
return rewriter.notifyMatchFailure(
op, "Expected `idx` in range of split_sizes");
Location loc = op.getLoc();
Value input = splitOp.getSelf();
Value dim = splitOp.getDim();
// only recompose to slice when split dim size is static, otherwise we need
// control flow like prim.if
Value dimSizeValue = rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim);
int64_t splitDimSize;
if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize)))
return rewriter.notifyMatchFailure(splitOp,
"split dim size must be static");
int64_t chunkSize = splitDimSize / sections;
int64_t remain = splitDimSize % sections;
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
Value result;
if (index < remain) {
Value start = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(index * (chunkSize + 1)));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((index + 1) * (chunkSize + 1)));
result = rewriter.create<AtenSliceTensorOp>(loc, op.getType(), input, dim,
start, end,
/*step=*/cstOne);
} else {
Value start = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(index * chunkSize + remain));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((index + 1) * chunkSize + remain));
result = rewriter.create<AtenSliceTensorOp>(loc, op.getType(), input, dim,
start, end,
/*step=*/cstOne);
}
rewriter.replaceOp(op, result);
// erase AtenTensorSplitSectionsOp if no user left
if (splitOp.getResult().use_empty())
rewriter.eraseOp(splitOp);
return success();
}
};
class RecomposeTensorSplitSectionsListUnpack
: public OpRewritePattern<PrimListUnpackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenTensorSplitSectionsOp + PrimListUnpackOp to
// AtenSliceTensorOps
auto splitOp = op.getOperand().getDefiningOp<AtenTensorSplitSectionsOp>();
if (!splitOp)
return rewriter.notifyMatchFailure(
op, "Input is not AtenTensorSplitSectionsOp");
if (isListPotentiallyMutated(splitOp.getResult()))
return rewriter.notifyMatchFailure(
op, "AtenTensorSplitSectionsOp result is potentially mutated");
int64_t sections;
if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(&sections)))
return rewriter.notifyMatchFailure(
op, "Expected `sections` of AtenTensorSplitSectionsOp to be a "
"constant int");
if (op->getNumResults() != sections)
return rewriter.notifyMatchFailure(
op, "`sections` must be same as ListUnpack's NumResults");
Location loc = op.getLoc();
Value input = splitOp.getSelf();
Value dim = splitOp.getDim();
// only recompose to slice when split dim size is static, otherwise we need
// control flow like prim.if
Value dimSizeValue = rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim);
int64_t splitDimSize;
if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize)))
return rewriter.notifyMatchFailure(splitOp,
"split dim size must be static");
int64_t chunkSize = splitDimSize / sections;
int64_t remain = splitDimSize % sections;
Value cstOne =
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
SmallVector<Value> results;
for (int64_t i = 0; i < sections; i++) {
if (i < remain) {
Value start = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i * (chunkSize + 1)));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((i + 1) * (chunkSize + 1)));
Value slice = rewriter.create<AtenSliceTensorOp>(
loc, op.getResult(i).getType(), input, dim, start, end,
/*step=*/cstOne);
results.push_back(slice);
} else {
Value start = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(i * chunkSize + remain));
Value end = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr((i + 1) * chunkSize + remain));
Value slice = rewriter.create<AtenSliceTensorOp>(
loc, op.getResult(i).getType(), input, dim, start, end,
/*step=*/cstOne);
results.push_back(slice);
}
}
rewriter.replaceOp(op, results);
// erase AtenTensorSplitSectionsOp if no user left
if (splitOp.getResult().use_empty())
rewriter.eraseOp(splitOp);
return success();
}
};
class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimListUnpackOp op,
PatternRewriter &rewriter) const override {
// recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps
auto chunkOp = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
auto chunkOp = op.getOperand().getDefiningOp<AtenChunkOp>();
if (!chunkOp)
return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp");
if (isListPotentiallyMutated(chunkOp.getResult()))
@ -470,10 +676,13 @@ public:
// chunkSize = floordiv(totalSize + chunks - 1, chunks)
Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks);
// add runtime.assert to check chunks == NumResults
// add runtime.assert to check floordiv(totalSize + chunkSize - 1,
// chunkSize) == NumResults
Value cstNumResults = rewriter.create<ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(op.getNumResults()));
Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, chunks, cstNumResults);
Value realChunks = getIntCeilDiv(rewriter, loc, totalSize, chunkSize);
Value eqOrNot =
rewriter.create<AtenEqIntOp>(loc, realChunks, cstNumResults);
rewriter.create<RuntimeAssertOp>(
loc, eqOrNot,
rewriter.getStringAttr(
@ -521,9 +730,15 @@ public:
// pattern.add calls go here
patterns.add<RecomposeSliceCopy_>(context);
patterns.add<RecomposeSelectFill_>(context);
patterns.add<RecomposeSplitTensorGetItemOp>(context);
// TODO: cloud move these patterns to Decompose pass, but should handle
// shape and value semantics carefully
patterns.add<RecomposeSplitTensorGetItem>(context);
patterns.add<RecomposeSplitTensorListUnpack>(context);
patterns.add<RecomposeSplitWithSizesGetItem>(context);
patterns.add<RecomposeSplitWithSizesListUnpack>(context);
patterns.add<RecomposeTensorSplitSectionsGetItem>(context);
patterns.add<RecomposeTensorSplitSectionsListUnpack>(context);
patterns.add<RecomposeUnbindListUnpack>(context);
patterns.add<RecomposeUnbindGetItem>(context);
patterns.add<RecomposeChunkListUnpack>(context);

View File

@ -21,7 +21,6 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"SplitWithSizes_Module_basic",
# lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
# these interpolate tests are added specifically to test onnx.Resize.
"InterpolateDynamicModule_sizes_bilinear",
@ -817,6 +816,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
}
STABLEHLO_PASS_SET = {
"SplitWithSizes_Module_basic",
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",
"AtenLinear1D_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
@ -1456,6 +1458,8 @@ STABLEHLO_CRASHING_SET = set()
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",
"AtenLinear2D_basic",
"AtenLinear3DBias_basic",
"ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@ -2594,6 +2598,10 @@ ONNX_XFAIL_SET = {
"_ConvolutionDeprecated2DDeterministicModule_basic",
"_SoftmaxModule_basic",
# Failure - onnx_import
# Failure - onnx_lowering: onnx.SplitToSequence
"ChunkListUnpackUneven_Module_basic",
"TensorSplitSections_GetItemModule_basic",
"TensorSplitSections_ListUnpackModule_basic",
# Failure - onnx_lowering: onnx.AveragePool
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
# these diagonal modules are currently failing due to dynamic shape.

View File

@ -969,7 +969,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])")
emit(
"aten::split.sizes : (Tensor, int[], int) -> (Tensor[])", has_canonicalizer=True
)
emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")

View File

@ -995,8 +995,8 @@ class ChunkListUnpackUneven_Module(torch.nn.Module):
]
)
def forward(self, x):
chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1)
return torch.ops.aten.add(chunk_0, chunk_1), chunk_2
a0, a1, a2, a3, a4 = torch.chunk(x, 6, 1)
return a0, a1, a2, a3, a4
@register_test_case(module_factory=lambda: ChunkListUnpackUneven_Module())
@ -1076,3 +1076,48 @@ class SplitWithSizes_Module(torch.nn.Module):
@register_test_case(module_factory=lambda: SplitWithSizes_Module())
def SplitWithSizes_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 2, 2))
# ==============================================================================
class TensorSplitSections_GetItemModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 5], torch.float32, True),
]
)
def forward(self, x):
split = torch.tensor_split(x, 3, dim=1)
return split[0], split[1], split[2]
@register_test_case(module_factory=lambda: TensorSplitSections_GetItemModule())
def TensorSplitSections_GetItemModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5))
class TensorSplitSections_ListUnpackModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args(
[
None,
([2, 5], torch.float32, True),
]
)
def forward(self, x):
a, b, c, d = torch.tensor_split(x, 4, dim=1)
return a, b, c, d
@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 5))