Generalize max_unpool lowering

jinchen62 2024-09-25 07:56:10 -07:00
parent 67732883fa
commit 9dcde43719
10 changed files with 97 additions and 78 deletions

View File

@ -7106,31 +7106,6 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
}]; }];
} }
def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$indices,
AnyTorchListOfTorchIntType:$output_size
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 3, 1);
}
void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 3, 1);
}
}];
}
def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
@ -7219,12 +7194,12 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
}]; }];
} }
def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [ def Torch_AtenMaxUnpoolOp : Torch_Op<"aten.max_unpool", [
AllowsTypeRefinement, AllowsTypeRefinement,
HasValueSemantics, HasValueSemantics,
ReadOnly ReadOnly
]> { ]> {
let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`"; let summary = "Generated op for `aten::max_unpool : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
let arguments = (ins let arguments = (ins
AnyTorchTensorType:$self, AnyTorchTensorType:$self,
AnyTorchTensorType:$indices, AnyTorchTensorType:$indices,
@ -7237,10 +7212,10 @@ def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
); );
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{ let extraClassDefinition = [{
ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) { ParseResult AtenMaxUnpoolOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1); return parseDefaultTorchOp(parser, result, 5, 1);
} }
void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) { void AtenMaxUnpoolOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1); printDefaultTorchOp(printer, *this, 5, 1);
} }
}]; }];

View File

@ -3430,11 +3430,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
SmallVector<int64_t> resultShape(resultType.getSizes()); SmallVector<int64_t> resultShape(resultType.getSizes());
Value resultShapeList = Value resultShapeList =
createConstantIntList(binder, rewriter, resultShape); createConstantIntList(binder, rewriter, resultShape);
if (rank == 4) {
rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool2dOp>(
binder.op, resultType, data, indices, resultShapeList);
return success();
}
SmallVector<int64_t> padding, strides; SmallVector<int64_t> padding, strides;
if (binder.s64IntegerArrayAttr(padding, "pads", {})) if (binder.s64IntegerArrayAttr(padding, "pads", {}))
@ -3469,7 +3464,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value paddingList = createConstantIntList(binder, rewriter, padding); Value paddingList = createConstantIntList(binder, rewriter, padding);
Value stridesList = createConstantIntList(binder, rewriter, strides); Value stridesList = createConstantIntList(binder, rewriter, strides);
rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpool3dOp>( rewriter.replaceOpWithNewOp<Torch::AtenMaxUnpoolOp>(
binder.op, resultType, data, indices, resultShapeList, stridesList, binder.op, resultType, data, indices, resultShapeList, stridesList,
paddingList); paddingList);
return success(); return success();

View File

@ -596,12 +596,12 @@ namespace {
// input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3). // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
// What worse, without knowing kernel size we cannot even reliably detect such // What worse, without knowing kernel size we cannot even reliably detect such
// cases and this conversion will just return invalid values. // cases and this conversion will just return invalid values.
class ConvertAtenMaxUnpool3dOp final class ConvertAtenMaxUnpoolOp final
: public OpConversionPattern<AtenMaxUnpool3dOp> { : public OpConversionPattern<AtenMaxUnpoolOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
LogicalResult LogicalResult
matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor, matchAndRewrite(AtenMaxUnpoolOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override { ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter))) if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure(); return failure();
@ -611,21 +611,22 @@ public:
Value self = adaptor.getSelf(); Value self = adaptor.getSelf();
auto selfType = cast<RankedTensorType>(self.getType()); auto selfType = cast<RankedTensorType>(self.getType());
ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3); size_t spatial = selfType.getRank() - 2;
ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
if (ShapedType::isDynamicShape(inputSize)) if (ShapedType::isDynamicShape(inputSize))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"input type must be of static shape"); "input type must be of static shape");
Value indices = adaptor.getIndices(); Value indices = adaptor.getIndices();
auto indicesType = cast<RankedTensorType>(indices.getType()); auto indicesType = cast<RankedTensorType>(indices.getType());
if (inputSize != indicesType.getShape().take_back(3)) if (inputSize != indicesType.getShape().take_back(spatial))
return rewriter.notifyMatchFailure(op, "input/indices shape mismatch"); return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
auto resType = typeConverter->convertType<RankedTensorType>(op.getType()); auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
if (!resType) if (!resType)
return rewriter.notifyMatchFailure(op, "invalid result type"); return rewriter.notifyMatchFailure(op, "invalid result type");
ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3); ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
if (ShapedType::isDynamicShape(inferredOutSize)) if (ShapedType::isDynamicShape(inferredOutSize))
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"output type must be of static shape"); "output type must be of static shape");
@ -636,7 +637,7 @@ public:
return rewriter.notifyMatchFailure(op, return rewriter.notifyMatchFailure(op,
"only support constant int output"); "only support constant int output");
if (inferredOutSize != ArrayRef(output)) if (inferredOutSize != ArrayRef(output).take_back(spatial))
return rewriter.notifyMatchFailure(op, "Invalid output size"); return rewriter.notifyMatchFailure(op, "Invalid output size");
} }
SmallVector<int64_t> stride; SmallVector<int64_t> stride;
@ -652,12 +653,12 @@ public:
// TODO: add support for asymmetric padding coming from "onnx.MaxUnpool" // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
// (padding.size() == 6). // (padding.size() == 6).
if (stride.size() != 3 || padding.size() != 3) if (stride.size() != spatial || padding.size() != spatial)
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
op, "stride and padding must be of size 3"); op, "stride and padding must be of size 3");
int64_t outRank = resType.getRank(); int64_t outRank = resType.getRank();
int64_t NC = outRank - 3; int64_t NC = outRank - spatial;
for (auto &&[inDim, outDim, str, pad] : for (auto &&[inDim, outDim, str, pad] :
llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) { llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@ -694,7 +695,7 @@ public:
// (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2) // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
// pad self and indices tensors to avoid out of bounds access. // pad self and indices tensors to avoid out of bounds access.
SmallVector<int64_t> expectedInputShape = SmallVector<int64_t> expectedInputShape =
llvm::to_vector(resType.getShape().drop_back(3)); llvm::to_vector(resType.getShape().drop_back(spatial));
for (auto &&[str, pad, resSize] : for (auto &&[str, pad, resSize] :
llvm::zip_equal(stride, padding, inferredOutSize)) llvm::zip_equal(stride, padding, inferredOutSize))
expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2); expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@ -707,7 +708,7 @@ public:
SmallVector<int64_t> low(outRank, 0); SmallVector<int64_t> low(outRank, 0);
SmallVector<int64_t> high(NC, 0); SmallVector<int64_t> high(NC, 0);
for (auto &&[inpSize, outSize] : llvm::zip_equal( for (auto &&[inpSize, outSize] : llvm::zip_equal(
inputSize, ArrayRef(expectedInputShape).take_back(3))) { inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
high.emplace_back(outSize - inpSize); high.emplace_back(outSize - inpSize);
} }
@ -1526,8 +1527,8 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter, patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
context); context);
target.addIllegalOp<AtenMaxUnpool3dOp>(); target.addIllegalOp<AtenMaxUnpoolOp>();
patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context); patterns.add<ConvertAtenMaxUnpoolOp>(typeConverter, context);
target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>(); target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
patterns patterns

View File

@ -8176,7 +8176,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n" " %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %1 : !torch.tuple<list<int>, list<int>>\n" " return %1 : !torch.tuple<list<int>, list<int>>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.max_unpool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n" " %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
" %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n" " %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
@ -12092,7 +12092,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n" " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n" " return %1 : !torch.tuple<int, int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.max_unpool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n" " return %0#1 : !torch.int\n"
" }\n" " }\n"

View File

@ -557,8 +557,8 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
"ElementwiseRreluTrainStaticModule_basic", "ElementwiseRreluTrainStaticModule_basic",
"MaxPool1dCeilModeTrueModule_basic", "MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic", "MaxPool1dStaticCeilModeTrueModule_basic",
"MaxUnpool3dModulePad0_basic", "MaxUnpoolModulePad0_basic",
"MaxUnpool3dModule_basic", "MaxUnpoolModule_basic",
"MultinomialModule2D_F32", "MultinomialModule2D_F32",
"MultinomialModule2D_basic", "MultinomialModule2D_basic",
"MultinomialModule_basic", "MultinomialModule_basic",
@ -2837,8 +2837,8 @@ ONNX_XFAIL_SET = {
"MaxPool3dWithIndicesNonDefaultDilationModule_basic", "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
"MaxPool3dWithIndicesNonDefaultParamsModule_basic", "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
"MaxPool3dWithIndicesNonDefaultStrideModule_basic", "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
"MaxUnpool3dModule_basic", "MaxUnpoolModule_basic",
"MaxUnpool3dModulePad0_basic", "MaxUnpoolModulePad0_basic",
"MeanDimEmptyDimModule_basic", "MeanDimEmptyDimModule_basic",
"Mlp1LayerModule_basic", "Mlp1LayerModule_basic",
"Mlp2LayerModuleNoBias_basic", "Mlp2LayerModuleNoBias_basic",
@ -3224,8 +3224,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"FakeQuantizePerTensorAffineCachemaskModule_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic", "IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic", "MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic", "MaxUnpoolModulePad0_basic",
"MaxUnpool3dModule_basic", "MaxUnpoolModule_basic",
"MultinomialModule2D_F32", "MultinomialModule2D_F32",
"MultinomialModule2D_basic", "MultinomialModule2D_basic",
"MultinomialModule_basic", "MultinomialModule_basic",
@ -3892,8 +3892,8 @@ ONNX_TOSA_XFAIL_SET = {
"FakeQuantizePerTensorAffineCachemaskModule_basic", "FakeQuantizePerTensorAffineCachemaskModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic", "IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic", "MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic", "MaxUnpoolModulePad0_basic",
"MaxUnpool3dModule_basic", "MaxUnpoolModule_basic",
"MultinomialModule2D_F32", "MultinomialModule2D_F32",
"MultinomialModule2D_basic", "MultinomialModule2D_basic",
"MultinomialModule_basic", "MultinomialModule_basic",

View File

@ -1056,14 +1056,14 @@ def atenmax_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode) maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
return maxpool3d, indices return maxpool3d, indices
def atenmax_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]: def atenmax_unpool〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5" assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
assert (len(output_size) == 3), "output_size must have 3 elements" assert (len(output_size) == 3 or len(output_size) == 2), "output_size has 3 or 2 elements"
assert (len(self) == len(indices)), "Input and indices must be of the same rank" assert (len(self) == len(indices)), "Input and indices must be of the same rank"
if len(self) == 5: if len(self) == 5:
return [self[0], self[1], output_size[0], output_size[1], output_size[2]] return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
else: else:
return [self[0], output_size[0], output_size[1], output_size[2]] return [self[0], self[1], output_size[0], output_size[1]]
def atenupsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]: def atenupsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
return input_size return input_size
@ -3179,7 +3179,7 @@ def atenmax_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
return self_dtype, torch.int64 return self_dtype, torch.int64
def atenmax_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int: def atenmax_unpool〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
self_rank, self_dtype = self_rank_dtype self_rank, self_dtype = self_rank_dtype
return self_dtype return self_dtype

View File

@ -618,7 +618,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
) )
emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
emit( emit(
"aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
has_canonicalizer=True, has_canonicalizer=True,
@ -627,7 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
"aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
) )
emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit("aten::max_unpool : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
emit( emit(
"aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
) )

View File

@ -1972,7 +1972,7 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
# ============================================================================== # ==============================================================================
class MaxUnpool3dModule(torch.nn.Module): class MaxUnpoolModule(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -1985,11 +1985,11 @@ class MaxUnpool3dModule(torch.nn.Module):
] ]
) )
def forward(self, x, indices): def forward(self, x, indices):
return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1)) return torch.ops.aten.max_unpool(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))
@register_test_case(module_factory=lambda: MaxUnpool3dModule()) @register_test_case(module_factory=lambda: MaxUnpoolModule())
def MaxUnpool3dModule_basic(module, tu: TestUtils): def MaxUnpoolModule_basic(module, tu: TestUtils):
input = tu.rand(2, 2, 4, 5, 6) input = tu.rand(2, 2, 4, 5, 6)
pool = torch.nn.MaxPool3d( pool = torch.nn.MaxPool3d(
kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
@ -2000,7 +2000,7 @@ def MaxUnpool3dModule_basic(module, tu: TestUtils):
# We have a special case for all-zeros padding, test it too. # We have a special case for all-zeros padding, test it too.
class MaxUnpool3dModulePad0(torch.nn.Module): class MaxUnpoolModulePad0(torch.nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
@ -2013,11 +2013,11 @@ class MaxUnpool3dModulePad0(torch.nn.Module):
] ]
) )
def forward(self, x, indices): def forward(self, x, indices):
return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0)) return torch.ops.aten.max_unpool(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))
@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0()) @register_test_case(module_factory=lambda: MaxUnpoolModulePad0())
def MaxUnpool3dModulePad0_basic(module, tu: TestUtils): def MaxUnpoolModulePad0_basic(module, tu: TestUtils):
input = tu.rand(2, 2, 4, 5, 6) input = tu.rand(2, 2, 4, 5, 6)
pool = torch.nn.MaxPool3d( pool = torch.nn.MaxPool3d(
kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True

View File

@ -1667,14 +1667,20 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc
// ----- // -----
// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape // CHECK-LABEL: func.func @test_maxunpool_2d_export_without_output_shape
func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1 // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[INT4:.*]] = torch.constant.int 4 // CHECK: %[[INT4:.*]] = torch.constant.int 4
// CHECK: %[[INT4_0:.*]] = torch.constant.int 4 // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
// CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32> // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_1:.*]] = torch.constant.int 0
// CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[INT2:.*]] = torch.constant.int 2
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.max_unpool %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
// return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32> // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
%0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
return %0 : !torch.vtensor<[1,1,4,4],f32> return %0 : !torch.vtensor<[1,1,4,4],f32>
@ -1682,8 +1688,8 @@ func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1
// ----- // -----
// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape // CHECK-LABEL: func.func @test_maxunpool_3d_export_without_output_shape
func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_maxunpool_3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1:.*]] = torch.constant.int 1
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1 // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[INT4:.*]] = torch.constant.int 4 // CHECK: %[[INT4:.*]] = torch.constant.int 4
@ -1698,7 +1704,7 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
// CHECK: %[[INT2_1:.*]] = torch.constant.int 2 // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
// CHECK: %[[INT2_2:.*]] = torch.constant.int 2 // CHECK: %[[INT2_2:.*]] = torch.constant.int 2
// CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int> // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32> // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
// return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32> // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32>
%0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
return %0 : !torch.vtensor<[1,1,4,4,4],f32> return %0 : !torch.vtensor<[1,1,4,4,4],f32>

View File

@ -95,3 +95,46 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
// CHECK: } -> tensor<?x?x?x?x?xf32> // CHECK: } -> tensor<?x?x?x?x?xf32>
return %4 : !torch.vtensor<[?,?,?,?,?],f32> return %4 : !torch.vtensor<[?,?,?,?,?],f32>
} }
// -----
// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%int1 = torch.constant.int 1
%int1_0 = torch.constant.int 1
%int4 = torch.constant.int 4
%int4_1 = torch.constant.int 4
%0 = torch.prim.ListConstruct %int1, %int1_0, %int4, %int4_1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
%int0 = torch.constant.int 0
%int0_2 = torch.constant.int 0
%1 = torch.prim.ListConstruct %int0, %int0_2 : (!torch.int, !torch.int) -> !torch.list<int>
%int2 = torch.constant.int 2
%int2_3 = torch.constant.int 2
%2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
%3 = torch.aten.max_unpool %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
// CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
// CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32>
// CHECK: %[[SHAPE:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4x4xf32>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]], %[[INDICES]] : tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>) outs(%[[SHAPE]] : tensor<?x?x4x4xf32>) {
// CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[CURRENT_INDEX:.*]]: i64, %[[OUT:.*]]: f32):
// CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
// CHECK-NEXT: %[[INDEX_CAST:.*]] = arith.index_cast %[[CURRENT_INDEX:.*]] : i64 to index
// CHECK-NEXT: %[[INDEX2:.*]] = linalg.index 2 : index
// CHECK-NEXT: %[[INDEX3:.*]] = linalg.index 3 : index
// CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : index
// CHECK-NEXT: %[[MULI:.*]] = arith.muli %[[INDEX2:.*]], %[[C4:.*]] : index
// CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[MULI:.*]], %[[INDEX3:.*]] : index
// CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi eq, %[[INDEX_CAST:.*]], %[[ADDI:.*]] : index
// CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMPI:.*]], %[[CURRENT_VALUE:.*]], %[[CST:.*]] : f32
// CHECK-NEXT: linalg.yield %[[SELECT:.*]] : f32
// CHECK: } -> tensor<?x?x4x4xf32>
// CHECK: %[[CAST:.*]] = tensor.cast %[[GENERIC]] : tensor<?x?x4x4xf32> to tensor<1x1x4x4xf32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x1x4x4xf32> -> !torch.vtensor<[1,1,4,4],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
return %3 : !torch.vtensor<[1,1,4,4],f32>
}