mirror of https://github.com/llvm/torch-mlir
[torch] Add support for `torch.split_with_sizes` via decompose (#2979)
Convert to individiual slices and tuple together as a list. --------- Co-authored-by: Scott Todd <scott.todd0@gmail.com>pull/2983/head
parent
933db87a07
commit
bc0527676b
|
@ -12910,6 +12910,30 @@ def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
]> {
|
||||
let summary = "Generated op for `aten::split.sizes : (Tensor, int[], int) -> (Tensor[])`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchListOfTorchIntType:$split_size,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchListOfTensorType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
ParseResult AtenSplitSizesOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||
return parseDefaultTorchOp(parser, result, 3, 1);
|
||||
}
|
||||
void AtenSplitSizesOp::print(OpAsmPrinter &printer) {
|
||||
printDefaultTorchOp(printer, *this, 3, 1);
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
|
||||
AllowsTypeRefinement,
|
||||
ReadOnly
|
||||
|
|
|
@ -693,6 +693,131 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposePrimTolistOp : public OpRewritePattern<PrimTolistOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(PrimTolistOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
auto self = op.getOperands()[0];
|
||||
auto selfTy = dyn_cast<Torch::BaseTensorType>(self.getType());
|
||||
if (!selfTy || !selfTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "Unknown self shape");
|
||||
|
||||
int64_t rank = selfTy.getSizes().size();
|
||||
if (rank != 1)
|
||||
return rewriter.notifyMatchFailure(op, "Expected rank-1");
|
||||
|
||||
int64_t length = selfTy.getSizes().back();
|
||||
if (length == Torch::kUnknownSize)
|
||||
return rewriter.notifyMatchFailure(op, "Tolist length is unknown");
|
||||
|
||||
auto resultTy = dyn_cast<Torch::ListType>(op.getType(0));
|
||||
if (!resultTy)
|
||||
return rewriter.notifyMatchFailure(op, "Result type is not list");
|
||||
|
||||
auto scalarTy = resultTy.getContainedType();
|
||||
Value zero =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
auto extractTy = rewriter.getType<ValueTensorType>(
|
||||
llvm::SmallVector<int64_t>{1}, selfTy.getOptionalDtype());
|
||||
llvm::SmallVector<Value> results;
|
||||
llvm::SmallVector<int64_t> sizes(selfTy.getSizes());
|
||||
for (int64_t i = 0; i < length; ++i) {
|
||||
Value iv =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
|
||||
Value extract = rewriter.create<AtenSelectIntOp>(
|
||||
loc, extractTy, self, /*dim=*/zero, /*index=*/iv);
|
||||
Value scalar = rewriter.create<AtenItemOp>(loc, scalarTy, extract);
|
||||
results.push_back(scalar);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, results);
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSplitSizesOp : public OpRewritePattern<AtenSplitSizesOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSplitSizesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
|
||||
op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenSplitWithSizesOp
|
||||
: public OpRewritePattern<AtenSplitWithSizesOp> {
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(AtenSplitWithSizesOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto loc = op.getLoc();
|
||||
Value self = op.getSelf();
|
||||
SmallVector<Value> splitSizes;
|
||||
if (!getListConstructElements(op.getSplitSizes(), splitSizes))
|
||||
return rewriter.notifyMatchFailure(op, "Unable to get sizes");
|
||||
|
||||
if (splitSizes.empty())
|
||||
return rewriter.notifyMatchFailure(op, "No split sizes");
|
||||
|
||||
auto selfTy = dyn_cast<BaseTensorType>(self.getType());
|
||||
if (!selfTy || !selfTy.hasSizes())
|
||||
return rewriter.notifyMatchFailure(op, "Self shape unknown");
|
||||
|
||||
int64_t rank = selfTy.getSizes().size();
|
||||
auto resultTy = dyn_cast<Torch::ListType>(op.getResult().getType());
|
||||
if (!resultTy)
|
||||
return rewriter.notifyMatchFailure(op, "Result type not a list");
|
||||
|
||||
auto sliceTy =
|
||||
dyn_cast_or_null<Torch::BaseTensorType>(resultTy.getContainedType());
|
||||
if (!isa<Torch::BaseTensorType>(sliceTy))
|
||||
return rewriter.notifyMatchFailure(op, "Slice type is unknown");
|
||||
|
||||
int64_t dimInt = 0;
|
||||
bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt));
|
||||
if (dimInt < 0)
|
||||
dimInt += rank;
|
||||
|
||||
auto intTy = rewriter.getType<Torch::IntType>();
|
||||
Value one =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
|
||||
Value begin =
|
||||
rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
|
||||
|
||||
llvm::SmallVector<Value> slices;
|
||||
llvm::SmallVector<int64_t> sliceSizes(sliceTy.getSizes());
|
||||
int64_t defaultLength = !hasDim ? Torch::kUnknownSize : sliceSizes[dimInt];
|
||||
for (auto size : splitSizes) {
|
||||
Value end = rewriter.create<AtenAddIntOp>(loc, intTy, begin, size);
|
||||
|
||||
int64_t sizeInt;
|
||||
if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) {
|
||||
sliceSizes[dimInt] = sizeInt;
|
||||
} else if (hasDim) {
|
||||
sliceSizes[dimInt] = defaultLength;
|
||||
}
|
||||
|
||||
sliceTy = rewriter.getType<ValueTensorType>(sliceSizes,
|
||||
sliceTy.getOptionalDtype());
|
||||
Value slice = rewriter.create<AtenSliceTensorOp>(
|
||||
loc, sliceTy, op.getSelf(),
|
||||
/*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one);
|
||||
slices.push_back(slice);
|
||||
begin = end;
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, slices);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class DecomposeAtenNarrowOp : public OpRewritePattern<AtenNarrowOp> {
|
||||
public:
|
||||
|
@ -7008,6 +7133,8 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitSizesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
|
||||
|
@ -7035,6 +7162,7 @@ public:
|
|||
addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimTolistOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
|
||||
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
|
||||
|
|
|
@ -20,7 +20,8 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
|
||||
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
|
||||
"IscloseStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic"
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
}
|
||||
|
||||
TORCHDYNAMO_XFAIL_SET = {
|
||||
|
@ -1478,15 +1479,6 @@ ONNX_XFAIL_SET = {
|
|||
"VarBiasedModule_basic",
|
||||
"VarMeanBiasedModule_basic",
|
||||
|
||||
# Failure - constant int lowering
|
||||
"SplitTensorGetItem_Module_basic",
|
||||
"SplitTensorLastSmallerModule_basic",
|
||||
"SplitTensorListUnpackModule_basic",
|
||||
"SplitTensorNegativeDimModule_basic",
|
||||
"SplitWithSizesListUnpackModule_basic",
|
||||
"UnbindIntGetItem_Module_basic",
|
||||
"UnbindIntListUnpack_Module_basic",
|
||||
|
||||
# Failure - incorrect numerics
|
||||
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
|
||||
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
|
||||
|
|
|
@ -741,6 +741,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
|||
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
|
||||
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
|
||||
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
|
||||
emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])")
|
||||
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
|
||||
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")
|
||||
|
||||
|
|
|
@ -897,3 +897,25 @@ class ChunkListUnpackUnevenDynamic_Module(torch.nn.Module):
|
|||
@register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module())
|
||||
def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 13, 2))
|
||||
|
||||
# ==============================================================================
|
||||
|
||||
class SplitWithSizes_Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([5, -1, -1], torch.float32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
split = torch.split(x, [2, 1, 2], dim=0)
|
||||
return split[0], split[1], split[2]
|
||||
|
||||
@register_test_case(module_factory=lambda: SplitWithSizes_Module())
|
||||
def SplitWithSizes_Module_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(5, 2, 2))
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue