[torch] Add support for `torch.split_with_sizes` via decompose (#2979)

Converts the op into individual slices and tuples them together as a list.
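
For intuition, the decomposition behaves like this PyTorch-level sketch (illustrative only; the pass itself emits aten.slice.Tensor ops plus a prim.ListConstruct, as in the DecomposeAtenSplitWithSizesOp pattern below):

    import torch

    x = torch.rand(5, 2, 2)
    sizes = [2, 1, 2]

    # Each output is the slice [begin, begin + size) along the split
    # dimension; begin advances by the size just consumed.
    begin, slices = 0, []
    for size in sizes:
        slices.append(x[begin:begin + size])
        begin += size

    # The slices agree elementwise with the builtin.
    for ours, ref in zip(slices, torch.split_with_sizes(x, sizes, dim=0)):
        assert torch.equal(ours, ref)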

---------

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
Rob Suderman 2024-03-05 15:01:21 -08:00 committed by GitHub
parent 933db87a07
commit bc0527676b
5 changed files with 177 additions and 10 deletions

@@ -12910,6 +12910,30 @@ def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [
  }];
}

def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [
    AllowsTypeRefinement,
    ReadOnly
  ]> {
  let summary = "Generated op for `aten::split.sizes : (Tensor, int[], int) -> (Tensor[])`";
  let arguments = (ins
    AnyTorchTensorType:$self,
    AnyTorchListOfTorchIntType:$split_size,
    Torch_IntType:$dim
  );
  let results = (outs
    AnyTorchListOfTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult AtenSplitSizesOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 3, 1);
    }
    void AtenSplitSizesOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 3, 1);
    }
  }];
}

def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [
    AllowsTypeRefinement,
    ReadOnly

@@ -693,6 +693,131 @@ public:
};
} // namespace

namespace {
class DecomposePrimTolistOp : public OpRewritePattern<PrimTolistOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(PrimTolistOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto self = op.getOperands()[0];
    auto selfTy = dyn_cast<Torch::BaseTensorType>(self.getType());
    if (!selfTy || !selfTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "Unknown self shape");

    int64_t rank = selfTy.getSizes().size();
    if (rank != 1)
      return rewriter.notifyMatchFailure(op, "Expected rank-1");

    int64_t length = selfTy.getSizes().back();
    if (length == Torch::kUnknownSize)
      return rewriter.notifyMatchFailure(op, "Tolist length is unknown");

    auto resultTy = dyn_cast<Torch::ListType>(op.getType(0));
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "Result type is not list");

    auto scalarTy = resultTy.getContainedType();
    Value zero =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));
    auto extractTy = rewriter.getType<ValueTensorType>(
        llvm::SmallVector<int64_t>{1}, selfTy.getOptionalDtype());
    llvm::SmallVector<Value> results;
    for (int64_t i = 0; i < length; ++i) {
      Value iv =
          rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(i));
      Value extract = rewriter.create<AtenSelectIntOp>(
          loc, extractTy, self, /*dim=*/zero, /*index=*/iv);
      Value scalar = rewriter.create<AtenItemOp>(loc, scalarTy, extract);
      results.push_back(scalar);
    }

    rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, results);
    return success();
  }
};
} // namespace
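
At the Python level, the prim.Tolist unrolling above corresponds roughly to the sketch below (one aten.select.int + aten.item pair per element, gathered into a prim.ListConstruct); it only applies when the rank-1 length is static:

    import torch

    t = torch.tensor([3, 1, 4])
    # One select + item per element, collected into a plain list:
    unrolled = [t.select(0, i).item() for i in range(t.shape[0])]
    assert unrolled == t.tolist()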

namespace {
class DecomposeAtenSplitSizesOp : public OpRewritePattern<AtenSplitSizesOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSplitSizesOp op,
                                PatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<AtenSplitWithSizesOp>(
        op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim());
    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenSplitWithSizesOp
    : public OpRewritePattern<AtenSplitWithSizesOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenSplitWithSizesOp op,
                                PatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    Value self = op.getSelf();
    SmallVector<Value> splitSizes;
    if (!getListConstructElements(op.getSplitSizes(), splitSizes))
      return rewriter.notifyMatchFailure(op, "Unable to get sizes");
    if (splitSizes.empty())
      return rewriter.notifyMatchFailure(op, "No split sizes");

    auto selfTy = dyn_cast<BaseTensorType>(self.getType());
    if (!selfTy || !selfTy.hasSizes())
      return rewriter.notifyMatchFailure(op, "Self shape unknown");

    int64_t rank = selfTy.getSizes().size();
    auto resultTy = dyn_cast<Torch::ListType>(op.getResult().getType());
    if (!resultTy)
      return rewriter.notifyMatchFailure(op, "Result type not a list");

    auto sliceTy =
        dyn_cast_or_null<Torch::BaseTensorType>(resultTy.getContainedType());
    if (!sliceTy)
      return rewriter.notifyMatchFailure(op, "Slice type is unknown");

    int64_t dimInt = 0;
    bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt));
    if (dimInt < 0)
      dimInt += rank;

    auto intTy = rewriter.getType<Torch::IntType>();
    Value one =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(1));
    Value begin =
        rewriter.create<ConstantIntOp>(loc, rewriter.getI64IntegerAttr(0));

    llvm::SmallVector<Value> slices;
    llvm::SmallVector<int64_t> sliceSizes(sliceTy.getSizes());
    int64_t defaultLength = !hasDim ? Torch::kUnknownSize : sliceSizes[dimInt];
    for (auto size : splitSizes) {
      Value end = rewriter.create<AtenAddIntOp>(loc, intTy, begin, size);
      int64_t sizeInt;
      if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) {
        sliceSizes[dimInt] = sizeInt;
      } else if (hasDim) {
        sliceSizes[dimInt] = defaultLength;
      }

      sliceTy = rewriter.getType<ValueTensorType>(sliceSizes,
                                                  sliceTy.getOptionalDtype());
      Value slice = rewriter.create<AtenSliceTensorOp>(
          loc, sliceTy, op.getSelf(),
          /*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one);
      slices.push_back(slice);
      begin = end;
    }

    rewriter.replaceOpWithNewOp<PrimListConstructOp>(op, resultTy, slices);
    return success();
  }
};
} // namespace
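
When both dim and a split size are compile-time constants, the pattern refines the slice's static shape along the split dimension; non-constant sizes fall back to the unknown-size default. A quick PyTorch-level check of the shapes this refinement should predict for the [2, 1, 2] split exercised by the new test below:

    import torch

    x = torch.rand(5, 2, 2)
    parts = torch.split_with_sizes(x, [2, 1, 2], dim=0)
    assert [tuple(p.shape) for p in parts] == [(2, 2, 2), (1, 2, 2), (2, 2, 2)]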
namespace {
class DecomposeAtenNarrowOp : public OpRewritePattern<AtenNarrowOp> {
public:
@@ -7008,6 +7133,8 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenVarCorrectionOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdDimOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenStdCorrectionOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSplitSizesOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenSplitWithSizesOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenNarrowTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenGluOp>(patterns);
@@ -7035,6 +7162,7 @@ public:
    addPatternIfTargetOpIsIllegal<DecomposeAtenNewEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenEmptyStridedOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenBucketizeTensorOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposePrimTolistOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposePrimsSqueezeOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenMovedimIntOp>(patterns);
    addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);

@@ -20,7 +20,8 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
    # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
    "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
    "IscloseStaticModule_basic",
    "IscloseStaticModuleTrue_basic",
    "SplitWithSizes_Module_basic",
}

TORCHDYNAMO_XFAIL_SET = {
@@ -1478,15 +1479,6 @@ ONNX_XFAIL_SET = {
"VarBiasedModule_basic",
"VarMeanBiasedModule_basic",
# Failure - constant int lowering
"SplitTensorGetItem_Module_basic",
"SplitTensorLastSmallerModule_basic",
"SplitTensorListUnpackModule_basic",
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"UnbindIntGetItem_Module_basic",
"UnbindIntListUnpack_Module_basic",
# Failure - incorrect numerics
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",

@@ -741,6 +741,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True)
emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])")
emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])")
emit("aten::unbind.int : (Tensor, int) -> (Tensor[])")
emit("aten::chunk : (Tensor, int, int) -> (Tensor[])")

@@ -897,3 +897,25 @@ class ChunkListUnpackUnevenDynamic_Module(torch.nn.Module):
@register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module())
def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(2, 13, 2))

# ==============================================================================

class SplitWithSizes_Module(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([5, -1, -1], torch.float32, True),
    ])
    def forward(self, x):
        split = torch.split(x, [2, 1, 2], dim=0)
        return split[0], split[1], split[2]

@register_test_case(module_factory=lambda: SplitWithSizes_Module())
def SplitWithSizes_Module_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 2, 2))