mirror of https://github.com/llvm/torch-mlir
[Torch] support recompose of aten.split.with_sizes and aten.tensor_split.sections (#3401)

* support recompose to aten.split.with_sizes and aten.tensor_split.sections
* fix recompose of aten.chunk

parent 074098d20c
commit 4e05e2cd1e
@@ -13526,6 +13526,31 @@ def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [
       printDefaultTorchOp(printer, *this, 3, 1);
     }
   }];
+  let hasCanonicalizer = 1;
 }
 
+def Torch_AtenTensorSplitSectionsOp : Torch_Op<"aten.tensor_split.sections", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    Torch_IntType:$sections,
+    Torch_IntType:$dim
+  );
+  let results = (outs
+    AnyTorchListOfTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenTensorSplitSectionsOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 3, 1);
+    }
+    void AtenTensorSplitSectionsOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 3, 1);
+    }
+  }];
+}
+
 def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
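Note: the new op mirrors `torch.tensor_split` called with an integer number of sections. A quick PyTorch-level sketch of the semantics being registered (example values are illustrative, not from the patch):

    import torch

    x = torch.arange(10)
    # aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])
    parts = torch.tensor_split(x, 3, dim=0)
    # Unlike aten.split.Tensor, the sections need not divide the dim evenly:
    # the first (10 % 3) == 1 chunk gets one extra element.
    print([p.numel() for p in parts])  # [4, 3, 3]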
@@ -3050,6 +3050,19 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
   });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenSplitSizesOp
+//===----------------------------------------------------------------------===//
+
+void AtenSplitSizesOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
+                                                   MLIRContext *context) {
+  patterns.add(+[](AtenSplitSizesOp op, PatternRewriter &rewriter) {
+    rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
+        op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
+    return success();
+  });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenIsFloatingPointOp
 //===----------------------------------------------------------------------===//
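Note: this canonicalizer encodes that `aten.split.sizes` is just another spelling of `aten.split_with_sizes`. A minimal sketch of the equivalence it relies on (illustrative values):

    import torch

    x = torch.rand(5, 2)
    a, b = torch.split(x, [2, 3], dim=0)
    a2, b2 = torch.split_with_sizes(x, [2, 3], dim=0)
    # the two ATen overloads agree, which is what the rewrite relies on
    assert torch.equal(a, a2) and torch.equal(b, b2)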
@@ -862,18 +862,6 @@ class DecomposePrimTolistOp : public OpRewritePattern<PrimTolistOp> {
 };
 } // namespace
 
-namespace {
-class DecomposeAtenSplitSizesOp : public OpRewritePattern<AtenSplitSizesOp> {
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenSplitSizesOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
-        op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
-    return success();
-  }
-};
-} // namespace
-
 namespace {
 class DecomposeAtenSplitWithSizesOp
     : public OpRewritePattern<AtenSplitWithSizesOp> {
@@ -8084,7 +8072,6 @@ public:
     addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
-    addPatternIfTargetOpIsIllegal<DecomposeAtenSplitSizesOp>(patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
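Note: with `aten.split.sizes` now normalized to `aten.split.with_sizes` by the canonicalizer above, the dedicated decompose pattern and its registration become redundant and are removed; presumably the recompose patterns below then only need to match the `split.with_sizes` spelling.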
@@ -13,6 +13,7 @@
 #include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -164,7 +165,7 @@ public:
   LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenUnbindOp + PrimListUnpackOp to select.int
-    auto unbindOp = dyn_cast<AtenUnbindIntOp>(op.getOperand().getDefiningOp());
+    auto unbindOp = op.getOperand().getDefiningOp<AtenUnbindIntOp>();
     if (!unbindOp)
       return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
     if (isListPotentiallyMutated(unbindOp.getResult()))
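Note: this and the similar hunks below replace `dyn_cast<OpTy>(value.getDefiningOp())` with `value.getDefiningOp<OpTy>()`. The two differ when the operand is a block argument: `getDefiningOp()` then returns null, and `dyn_cast` asserts on a null pointer, whereas the templated `getDefiningOp<OpTy>()` behaves like `dyn_cast_or_null` and simply yields a null op, which the following `if (!unbindOp)` check already handles.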
@@ -207,7 +208,7 @@ public:
   LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenUnbindIntOp + __getitem__t to select.int
-    auto unbind = dyn_cast<AtenUnbindIntOp>(op.getList().getDefiningOp());
+    auto unbind = op.getList().getDefiningOp<AtenUnbindIntOp>();
     if (!unbind)
       return rewriter.notifyMatchFailure(op, "Input is not AtenUnbindIntOp");
     if (isListPotentiallyMutated(unbind.getResult()))
@@ -243,15 +244,14 @@ public:
   }
 };
 
-class RecomposeSplitTensorGetItemOp
+class RecomposeSplitTensorGetItem
     : public OpRewritePattern<Aten__Getitem__TOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenSplitTensorOp + __getitem__t to AtenSliceTensorOp
-    auto splitTensorOp =
-        dyn_cast<AtenSplitTensorOp>(op.getList().getDefiningOp());
+    auto splitTensorOp = op.getList().getDefiningOp<AtenSplitTensorOp>();
     if (!splitTensorOp)
       return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
     if (isListPotentiallyMutated(splitTensorOp.getResult()))
@@ -308,8 +308,7 @@ public:
   LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenSplitTensorOp + PrimListUnpackOp to AtenSliceTensorOps
-    auto splitTensorOp =
-        dyn_cast<AtenSplitTensorOp>(op.getOperand().getDefiningOp());
+    auto splitTensorOp = op.getOperand().getDefiningOp<AtenSplitTensorOp>();
     if (!splitTensorOp)
       return rewriter.notifyMatchFailure(op, "Input is not AtenSplitTensorOp");
     if (isListPotentiallyMutated(splitTensorOp.getResult()))
@@ -362,6 +361,78 @@ public:
   }
 };
 
+class RecomposeSplitWithSizesGetItem
+    : public OpRewritePattern<Aten__Getitem__TOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenSplitWithSizes + __getitem__t to AtenSliceTensorOp
+    auto splitWithSizesOp = op.getList().getDefiningOp<AtenSplitWithSizesOp>();
+    if (!splitWithSizesOp)
+      return rewriter.notifyMatchFailure(op,
+                                         "Input is not AtenSplitWithSizesOp");
+    if (isListPotentiallyMutated(splitWithSizesOp.getResult()))
+      return rewriter.notifyMatchFailure(
+          op, "AtenSplitWithSizesOp result is potentially mutated");
+    if (isListPotentiallyMutated(splitWithSizesOp.getSplitSizes())) {
+      return rewriter.notifyMatchFailure(
+          op, "splitWithSizesOp's split_sizes is potentially mutated");
+    }
+
+    SmallVector<int64_t> splitSizes;
+    if (!matchPattern(splitWithSizesOp.getSplitSizes(),
+                      m_TorchListOfConstantInts(splitSizes))) {
+      return rewriter.notifyMatchFailure(
+          op, "split_sizes must be list of constant int");
+    }
+
+    int64_t index;
+    if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
+    index = toPositiveDim(index, splitSizes.size());
+    if (!isValidDim(index, splitSizes.size()))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` in range of split_sizes");
+
+    Location loc = op.getLoc();
+    Value input = splitWithSizesOp.getSelf();
+    Value dim = splitWithSizesOp.getDim();
+
+    // add runtime.assert to check dimension constraint
+    Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    int64_t sumSplitSize =
+        std::accumulate(splitSizes.begin(), splitSizes.end(), 0);
+    Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(sumSplitSize));
+    Value eqOrNot =
+        rewriter.create<AtenEqIntOp>(loc, totalSize, cstSumSplitSize);
+    rewriter.create<RuntimeAssertOp>(
+        loc, eqOrNot,
+        rewriter.getStringAttr("split dim must be sum of split_sizes"));
+
+    // replace with AtenSliceTensorOp
+    SmallVector<int64_t> boundaryOfSliceOp(splitSizes.size() + 1, 0);
+    for (size_t i = 1; i < boundaryOfSliceOp.size(); i++) {
+      boundaryOfSliceOp[i] = boundaryOfSliceOp[i - 1] + splitSizes[i - 1];
+    }
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    auto start = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index]));
+    auto end = rewriter.create<Torch::ConstantIntOp>(
+        loc, rewriter.getI64IntegerAttr(boundaryOfSliceOp[index + 1]));
+    Value slice = rewriter.create<AtenSliceTensorOp>(
+        loc, op.getType(), input, dim, start, end, /*step=*/cstOne);
+    rewriter.replaceOp(op, slice);
+    // erase splitOp if no user left
+    if (splitWithSizesOp.getResult().use_empty())
+      rewriter.eraseOp(splitWithSizesOp);
+    return success();
+  }
+};
+
 class RecomposeSplitWithSizesListUnpack
     : public OpRewritePattern<PrimListUnpackOp> {
 public:
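Note: `boundaryOfSliceOp` is a running prefix sum over `split_sizes`, so item i of the split is exactly the slice between consecutive boundaries. A small Python sketch of the arithmetic (names are illustrative):

    import torch

    x = torch.arange(7)
    split_sizes = [1, 2, 4]
    bounds = [0]                     # boundaryOfSliceOp: [0, 1, 3, 7]
    for s in split_sizes:
        bounds.append(bounds[-1] + s)
    idx = 1
    via_split = torch.split(x, split_sizes, dim=0)[idx]
    via_slice = x[bounds[idx]:bounds[idx + 1]]  # what the pattern emits as aten.slice.Tensor
    assert torch.equal(via_split, via_slice)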
@@ -369,8 +440,7 @@ public:
   LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenSplitWithSizesOp + PrimListUnpackOp to AtenSliceTensorOps
-    auto splitOp =
-        dyn_cast<AtenSplitWithSizesOp>(op.getOperand().getDefiningOp());
+    auto splitOp = op.getOperand().getDefiningOp<AtenSplitWithSizesOp>();
     if (!splitOp) {
       return rewriter.notifyMatchFailure(op,
                                          "Input is not AtenSplitWithSizesOp");
@@ -390,20 +460,11 @@ public:
           op, "split_sizes is not from PrimListConstructOp");
     }
 
-    int64_t sumSplitSize = 0;
     SmallVector<int64_t> splitSizes;
-    for (auto operand : splitSizesConstruct.getOperands()) {
-      int64_t value = -1;
-      // TODO: support when split_sizes are not constant int
-      if (!matchPattern(operand, m_TorchConstantInt(&value))) {
-        return rewriter.notifyMatchFailure(
-            op, "one of split_sizes is not constant int");
-      }
-      if (value < 0) {
-        return rewriter.notifyMatchFailure(op, "all of split_sizes must > 0");
-      }
-      sumSplitSize += value;
-      splitSizes.push_back(value);
+    if (!matchPattern(splitOp.getSplitSizes(),
+                      m_TorchListOfConstantInts(splitSizes))) {
+      return rewriter.notifyMatchFailure(
+          op, "split_sizes must be list of constant int");
     }
     if (splitSizes.size() != op.getNumResults()) {
       return rewriter.notifyMatchFailure(
@@ -416,6 +477,8 @@ public:
 
     // add runtime.assert to check rank constraint
     Value totalSize = rewriter.create<AtenSizeIntOp>(loc, input, dim);
+    int64_t sumSplitSize =
+        std::accumulate(splitSizes.begin(), splitSizes.end(), 0);
     Value cstSumSplitSize = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(sumSplitSize));
     Value eqOrNot =
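Note: the emitted runtime.assert mirrors the eager-mode contract that the sizes must sum to the size of the split dimension, e.g.:

    import torch

    x = torch.rand(5, 2)
    torch.split_with_sizes(x, [2, 3], dim=0)    # ok: 2 + 3 == 5
    # torch.split_with_sizes(x, [2, 2], dim=0)  # raises: sizes must sum to 5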
@@ -450,13 +513,156 @@ public:
   }
 };
 
+class RecomposeTensorSplitSectionsGetItem
+    : public OpRewritePattern<Aten__Getitem__TOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(Aten__Getitem__TOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenTensorSplitSectionsOp + __getitem__t to AtenSliceTensorOp
+    auto splitOp = op.getList().getDefiningOp<AtenTensorSplitSectionsOp>();
+    if (!splitOp)
+      return rewriter.notifyMatchFailure(
+          op, "Input is not AtenTensorSplitSectionsOp");
+    if (isListPotentiallyMutated(splitOp.getResult()))
+      return rewriter.notifyMatchFailure(
+          op, "AtenTensorSplitSectionsOp result is potentially mutated");
+
+    int64_t sections;
+    if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(&sections)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `sections` of AtenTensorSplitSectionsOp to be a "
+              "constant int");
+
+    int64_t index;
+    if (!matchPattern(op.getIdx(), m_TorchConstantInt(&index)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` of `Aten__Getitem__TOp` to be a constant int");
+    index = toPositiveDim(index, sections);
+    if (!isValidDim(index, sections))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `idx` in range of split_sizes");
+
+    Location loc = op.getLoc();
+    Value input = splitOp.getSelf();
+    Value dim = splitOp.getDim();
+
+    // only recompose to slice when split dim size is static, otherwise we need
+    // control flow like prim.if
+    Value dimSizeValue = rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim);
+    int64_t splitDimSize;
+    if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize)))
+      return rewriter.notifyMatchFailure(splitOp,
+                                         "split dim size must be static");
+
+    int64_t chunkSize = splitDimSize / sections;
+    int64_t remain = splitDimSize % sections;
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    Value result;
+    if (index < remain) {
+      Value start = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(index * (chunkSize + 1)));
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr((index + 1) * (chunkSize + 1)));
+      result = rewriter.create<AtenSliceTensorOp>(loc, op.getType(), input, dim,
+                                                  start, end,
+                                                  /*step=*/cstOne);
+    } else {
+      Value start = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(index * chunkSize + remain));
+      Value end = rewriter.create<ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr((index + 1) * chunkSize + remain));
+      result = rewriter.create<AtenSliceTensorOp>(loc, op.getType(), input, dim,
+                                                  start, end,
+                                                  /*step=*/cstOne);
+    }
+    rewriter.replaceOp(op, result);
+    // erase AtenTensorSplitSectionsOp if no user left
+    if (splitOp.getResult().use_empty())
+      rewriter.eraseOp(splitOp);
+    return success();
+  }
+};
+
+class RecomposeTensorSplitSectionsListUnpack
+    : public OpRewritePattern<PrimListUnpackOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(PrimListUnpackOp op,
+                                PatternRewriter &rewriter) const override {
+    // recompose AtenTensorSplitSectionsOp + PrimListUnpackOp to
+    // AtenSliceTensorOps
+    auto splitOp = op.getOperand().getDefiningOp<AtenTensorSplitSectionsOp>();
+    if (!splitOp)
+      return rewriter.notifyMatchFailure(
+          op, "Input is not AtenTensorSplitSectionsOp");
+    if (isListPotentiallyMutated(splitOp.getResult()))
+      return rewriter.notifyMatchFailure(
+          op, "AtenTensorSplitSectionsOp result is potentially mutated");
+
+    int64_t sections;
+    if (!matchPattern(splitOp.getSections(), m_TorchConstantInt(&sections)))
+      return rewriter.notifyMatchFailure(
+          op, "Expected `sections` of AtenTensorSplitSectionsOp to be a "
+              "constant int");
+    if (op->getNumResults() != sections)
+      return rewriter.notifyMatchFailure(
+          op, "`sections` must be same as ListUnpack's NumResults");
+
+    Location loc = op.getLoc();
+    Value input = splitOp.getSelf();
+    Value dim = splitOp.getDim();
+
+    // only recompose to slice when split dim size is static, otherwise we need
+    // control flow like prim.if
+    Value dimSizeValue = rewriter.createOrFold<AtenSizeIntOp>(loc, input, dim);
+    int64_t splitDimSize;
+    if (!matchPattern(dimSizeValue, m_TorchConstantInt(&splitDimSize)))
+      return rewriter.notifyMatchFailure(splitOp,
+                                         "split dim size must be static");
+
+    int64_t chunkSize = splitDimSize / sections;
+    int64_t remain = splitDimSize % sections;
+    Value cstOne =
+        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
+    SmallVector<Value> results;
+    for (int64_t i = 0; i < sections; i++) {
+      if (i < remain) {
+        Value start = rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(i * (chunkSize + 1)));
+        Value end = rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr((i + 1) * (chunkSize + 1)));
+        Value slice = rewriter.create<AtenSliceTensorOp>(
+            loc, op.getResult(i).getType(), input, dim, start, end,
+            /*step=*/cstOne);
+        results.push_back(slice);
+      } else {
+        Value start = rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(i * chunkSize + remain));
+        Value end = rewriter.create<ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr((i + 1) * chunkSize + remain));
+        Value slice = rewriter.create<AtenSliceTensorOp>(
+            loc, op.getResult(i).getType(), input, dim, start, end,
+            /*step=*/cstOne);
+        results.push_back(slice);
+      }
+    }
+    rewriter.replaceOp(op, results);
+    // erase AtenTensorSplitSectionsOp if no user left
+    if (splitOp.getResult().use_empty())
+      rewriter.eraseOp(splitOp);
+    return success();
+  }
+};
+
 class RecomposeChunkListUnpack : public OpRewritePattern<PrimListUnpackOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
   LogicalResult matchAndRewrite(PrimListUnpackOp op,
                                 PatternRewriter &rewriter) const override {
     // recompose AtenChunkOp + PrimListUnpackOp to AtenSliceTensorOps
-    auto chunkOp = dyn_cast<AtenChunkOp>(op.getOperand().getDefiningOp());
+    auto chunkOp = op.getOperand().getDefiningOp<AtenChunkOp>();
     if (!chunkOp)
       return rewriter.notifyMatchFailure(op, "Input is not AtenChunkOp");
     if (isListPotentiallyMutated(chunkOp.getResult()))
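Note: the `chunkSize`/`remain` arithmetic matches `torch.tensor_split`'s documented behavior: for n sections over a dimension of size d, the first d % n outputs have size d // n + 1 and the remaining outputs have size d // n. A quick check of the slice boundaries the pattern computes (values illustrative):

    import torch

    d, sections = 10, 4
    x = torch.arange(d)
    chunk, remain = d // sections, d % sections   # 2, 2
    bounds = []
    for i in range(sections):
        if i < remain:
            bounds.append((i * (chunk + 1), (i + 1) * (chunk + 1)))
        else:
            bounds.append((i * chunk + remain, (i + 1) * chunk + remain))
    expected = [x[s:e] for s, e in bounds]
    actual = torch.tensor_split(x, sections, dim=0)
    assert all(torch.equal(a, b) for a, b in zip(actual, expected))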
@@ -470,10 +676,13 @@ public:
     // chunkSize = floordiv(totalSize + chunks - 1, chunks)
     Value chunkSize = getIntCeilDiv(rewriter, loc, totalSize, chunks);
 
-    // add runtime.assert to check chunks == NumResults
+    // add runtime.assert to check floordiv(totalSize + chunkSize - 1,
+    // chunkSize) == NumResults
     Value cstNumResults = rewriter.create<ConstantIntOp>(
         loc, rewriter.getI64IntegerAttr(op.getNumResults()));
-    Value eqOrNot = rewriter.create<AtenEqIntOp>(loc, chunks, cstNumResults);
+    Value realChunks = getIntCeilDiv(rewriter, loc, totalSize, chunkSize);
+    Value eqOrNot =
+        rewriter.create<AtenEqIntOp>(loc, realChunks, cstNumResults);
     rewriter.create<RuntimeAssertOp>(
         loc, eqOrNot,
         rewriter.getStringAttr(
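Note: the old assertion `chunks == NumResults` was too strict. `torch.chunk` may return fewer chunks than requested, because each chunk has size ceil(d / chunks); the fix recomputes the actual chunk count as ceil(totalSize / chunkSize) and asserts against that instead. For example:

    import torch

    parts = torch.chunk(torch.rand(2, 5), 6, dim=1)  # request 6 chunks
    print(len(parts))  # 5 -- chunk size is ceil(5/6) == 1, so only 5 chunks exist

The updated `ChunkListUnpackUneven_Module` test below exercises exactly this case (6 requested chunks, 5 results).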
@@ -521,9 +730,15 @@ public:
     // pattern.add calls go here
     patterns.add<RecomposeSliceCopy_>(context);
     patterns.add<RecomposeSelectFill_>(context);
-    patterns.add<RecomposeSplitTensorGetItemOp>(context);
+
+    // TODO: could move these patterns to Decompose pass, but should handle
+    // shape and value semantics carefully
+    patterns.add<RecomposeSplitTensorGetItem>(context);
     patterns.add<RecomposeSplitTensorListUnpack>(context);
+    patterns.add<RecomposeSplitWithSizesGetItem>(context);
     patterns.add<RecomposeSplitWithSizesListUnpack>(context);
+    patterns.add<RecomposeTensorSplitSectionsGetItem>(context);
+    patterns.add<RecomposeTensorSplitSectionsListUnpack>(context);
     patterns.add<RecomposeUnbindListUnpack>(context);
     patterns.add<RecomposeUnbindGetItem>(context);
     patterns.add<RecomposeChunkListUnpack>(context);
@@ -21,7 +21,6 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
     "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
     "IscloseStaticModule_basic",
     "IscloseStaticModuleTrue_basic",
-    "SplitWithSizes_Module_basic",
     # lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
     # these interpolate tests are added specifically to test onnx.Resize.
     "InterpolateDynamicModule_sizes_bilinear",
@@ -817,6 +816,9 @@ FX_IMPORTER_STABLEHLO_CRASHING_SET = {
 }
 
 STABLEHLO_PASS_SET = {
+    "SplitWithSizes_Module_basic",
+    "TensorSplitSections_GetItemModule_basic",
+    "TensorSplitSections_ListUnpackModule_basic",
     "AtenLinear1D_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",
@@ -1456,6 +1458,8 @@ STABLEHLO_CRASHING_SET = set()
 # Write the TOSA set as a "passing" set as it is very early in development
 # and very few tests work yet.
 TOSA_PASS_SET = {
+    "TensorSplitSections_GetItemModule_basic",
+    "TensorSplitSections_ListUnpackModule_basic",
     "AtenLinear2D_basic",
     "AtenLinear3DBias_basic",
     "ElementwiseAddScalar_NumToTensorFloat_Module_basic",
@@ -2594,6 +2598,10 @@ ONNX_XFAIL_SET = {
     "_ConvolutionDeprecated2DDeterministicModule_basic",
     "_SoftmaxModule_basic",
     # Failure - onnx_import
+    # Failure - onnx_lowering: onnx.SplitToSequence
+    "ChunkListUnpackUneven_Module_basic",
+    "TensorSplitSections_GetItemModule_basic",
+    "TensorSplitSections_ListUnpackModule_basic",
     # Failure - onnx_lowering: onnx.AveragePool
     "AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
     # these diagonal modules are currently failing due to dynamic shape.
@@ -969,7 +969,10 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
     emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
     emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
-    emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])")
+    emit(
+        "aten::split.sizes : (Tensor, int[], int) -> (Tensor[])", has_canonicalizer=True
+    )
+    emit("aten::tensor_split.sections : (Tensor, int, int) -> (Tensor[])")
     emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
     emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
 
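Note: passing `has_canonicalizer=True` to `emit` is what makes the ODS generator add the `let hasCanonicalizer = 1;` line to the generated op definition in the first hunk; the canonicalization body itself is the C++ `getCanonicalizationPatterns` shown earlier.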
@@ -995,8 +995,8 @@ class ChunkListUnpackUneven_Module(torch.nn.Module):
         ]
     )
     def forward(self, x):
-        chunk_0, chunk_1, chunk_2 = torch.chunk(x, 3, 1)
-        return torch.ops.aten.add(chunk_0, chunk_1), chunk_2
+        a0, a1, a2, a3, a4 = torch.chunk(x, 6, 1)
+        return a0, a1, a2, a3, a4
 
 
 @register_test_case(module_factory=lambda: ChunkListUnpackUneven_Module())
@@ -1076,3 +1076,48 @@ class SplitWithSizes_Module(torch.nn.Module):
 @register_test_case(module_factory=lambda: SplitWithSizes_Module())
 def SplitWithSizes_Module_basic(module, tu: TestUtils):
     module.forward(tu.rand(5, 2, 2))
+
+
+# ==============================================================================
+
+
+class TensorSplitSections_GetItemModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 5], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        split = torch.tensor_split(x, 3, dim=1)
+        return split[0], split[1], split[2]
+
+
+@register_test_case(module_factory=lambda: TensorSplitSections_GetItemModule())
+def TensorSplitSections_GetItemModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5))
+
+
+class TensorSplitSections_ListUnpackModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 5], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        a, b, c, d = torch.tensor_split(x, 4, dim=1)
+        return a, b, c, d
+
+
+@register_test_case(module_factory=lambda: TensorSplitSections_ListUnpackModule())
+def TensorSplitSections_ListUnpackModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 5))
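For reference, the shapes the new tests exercise: tensor_split of a width-5 dimension into 3 sections yields widths [2, 2, 1], and into 4 sections yields [2, 1, 1, 1]:

    import torch

    x = torch.rand(2, 5)
    print([t.shape[1] for t in torch.tensor_split(x, 3, dim=1)])  # [2, 2, 1]
    print([t.shape[1] for t in torch.tensor_split(x, 4, dim=1)])  # [2, 1, 1, 1]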