From 9dcde4371999b29535d282819b315a1b1af4f712 Mon Sep 17 00:00:00 2001
From: jinchen62
Date: Wed, 25 Sep 2024 07:56:10 -0700
Subject: [PATCH] Generalize max_unpool lowering

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 33 ++------
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp    |  7 +--
 lib/Conversion/TorchToLinalg/Pooling.cpp      | 27 ++++++------
 .../Transforms/AbstractInterpLibrary.cpp      |  4 +-
 projects/pt1/e2e_testing/xfail_sets.py        | 16 +++---
 .../build_tools/abstract_interp_lib_gen.py    |  8 ++--
 .../build_tools/torch_ods_gen.py              |  3 +-
 .../torch_mlir_e2e_test/test_suite/pooling.py | 16 +++---
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir   | 18 +++--
 test/Conversion/TorchToLinalg/pooling.mlir    | 43 +++++++++++++++
 10 files changed, 97 insertions(+), 78 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 0b1a8b257..ccf64848b 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7106,31 +7106,6 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   }];
 }
 
-def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [
-    AllowsTypeRefinement,
-    HasValueSemantics,
-    ReadOnly
-  ]> {
-  let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`";
-  let arguments = (ins
-    AnyTorchTensorType:$self,
-    AnyTorchTensorType:$indices,
-    AnyTorchListOfTorchIntType:$output_size
-  );
-  let results = (outs
-    AnyTorchOptionalTensorType:$result
-  );
-  let hasCustomAssemblyFormat = 1;
-  let extraClassDefinition = [{
-    ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) {
-      return parseDefaultTorchOp(parser, result, 3, 1);
-    }
-    void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) {
-      printDefaultTorchOp(printer, *this, 3, 1);
-    }
-  }];
-}
-
 def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [
   AllowsTypeRefinement,
   HasValueSemantics,
@@ -7219,12 +7194,12 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [
   }];
 }
 
-def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [
+def Torch_AtenMaxUnpoolOp : Torch_Op<"aten.max_unpool", [
   AllowsTypeRefinement,
   HasValueSemantics,
   ReadOnly
 ]> {
-  let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
+  let summary = "Generated op for `aten::max_unpool : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
     AnyTorchTensorType:$indices,
     AnyTorchListOfTorchIntType:$output_size,
     AnyTorchListOfTorchIntType:$stride,
     AnyTorchListOfTorchIntType:$padding
   );
   let results = (outs
     AnyTorchOptionalTensorType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenMaxUnpoolOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 5, 1);
     }
-    void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) {
+    void AtenMaxUnpoolOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 5, 1);
     }
   }];
 }
 
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 168040d9b..6b2e7d8de 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -3430,11 +3430,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
         SmallVector<int64_t> resultShape(resultType.getSizes());
         Value resultShapeList =
             createConstantIntList(binder, rewriter, resultShape);
-        if (rank == 4) {
-          rewriter.replaceOpWithNewOp<AtenMaxUnpool2dOp>(
-              binder.op, resultType, data, indices, resultShapeList);
-          return success();
-        }
 
         SmallVector<int64_t> padding, strides;
         if (binder.s64IntegerArrayAttr(padding, "pads", {}))
@@ -3469,7 +3464,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
 
         Value paddingList = createConstantIntList(binder, rewriter, padding);
         Value stridesList = createConstantIntList(binder, rewriter, strides);
-        rewriter.replaceOpWithNewOp<AtenMaxUnpool3dOp>(
+        rewriter.replaceOpWithNewOp<AtenMaxUnpoolOp>(
             binder.op, resultType, data, indices, resultShapeList, stridesList,
             paddingList);
         return success();
diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp
index 90b5b2af7..e3304e9fd 100644
--- a/lib/Conversion/TorchToLinalg/Pooling.cpp
+++ b/lib/Conversion/TorchToLinalg/Pooling.cpp
@@ -596,12 +596,12 @@ namespace {
 // input_size=2, output_size=5 and stride=2, kernel_size can be either 2 or 3).
 // What's worse, without knowing the kernel size we cannot even reliably detect
 // such cases, and this conversion will just return invalid values.
-class ConvertAtenMaxUnpool3dOp final
-    : public OpConversionPattern<AtenMaxUnpool3dOp> {
+class ConvertAtenMaxUnpoolOp final
+    : public OpConversionPattern<AtenMaxUnpoolOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
   LogicalResult
-  matchAndRewrite(AtenMaxUnpool3dOp op, OpAdaptor adaptor,
+  matchAndRewrite(AtenMaxUnpoolOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
       return failure();
@@ -611,21 +611,22 @@ public:
     Value self = adaptor.getSelf();
     auto selfType = cast<RankedTensorType>(self.getType());
 
-    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(3);
+    size_t spatial = selfType.getRank() - 2;
+    ArrayRef<int64_t> inputSize = selfType.getShape().take_back(spatial);
     if (ShapedType::isDynamicShape(inputSize))
       return rewriter.notifyMatchFailure(op,
                                          "input type must be of static shape");
 
     Value indices = adaptor.getIndices();
     auto indicesType = cast<RankedTensorType>(indices.getType());
-    if (inputSize != indicesType.getShape().take_back(3))
+    if (inputSize != indicesType.getShape().take_back(spatial))
       return rewriter.notifyMatchFailure(op, "input/indices shape mismatch");
 
     auto resType = typeConverter->convertType<RankedTensorType>(op.getType());
     if (!resType)
       return rewriter.notifyMatchFailure(op, "invalid result type");
 
-    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(3);
+    ArrayRef<int64_t> inferredOutSize = resType.getShape().take_back(spatial);
     if (ShapedType::isDynamicShape(inferredOutSize))
       return rewriter.notifyMatchFailure(op,
                                          "output type must be of static shape");
@@ -636,7 +637,7 @@ public:
         return rewriter.notifyMatchFailure(op,
                                            "only support constant int output");
 
-      if (inferredOutSize != ArrayRef(output))
+      if (inferredOutSize != ArrayRef(output).take_back(spatial))
         return rewriter.notifyMatchFailure(op, "Invalid output size");
     }
     SmallVector<int64_t> stride;
@@ -652,12 +653,12 @@ public:
 
     // TODO: add support for asymmetric padding coming from "onnx.MaxUnpool"
     // (padding.size() == 6).
-    if (stride.size() != 3 || padding.size() != 3)
-      return rewriter.notifyMatchFailure(
-          op, "stride and padding must be of size 3");
+    if (stride.size() != spatial || padding.size() != spatial)
+      return rewriter.notifyMatchFailure(
+          op, "stride and padding must match the number of spatial dims");
 
     int64_t outRank = resType.getRank();
-    int64_t NC = outRank - 3;
+    int64_t NC = outRank - spatial;
 
     for (auto &&[inDim, outDim, str, pad] :
          llvm::zip_equal(inputSize, inferredOutSize, stride, padding)) {
@@ -694,7 +695,7 @@ public:
     // (e.g. pooling_input_size=5, kernel_size=2, stride=2, output_size=2)
    // pad self and indices tensors to avoid out of bounds access.
     SmallVector<int64_t> expectedInputShape =
-        llvm::to_vector(resType.getShape().drop_back(3));
+        llvm::to_vector(resType.getShape().drop_back(spatial));
     for (auto &&[str, pad, resSize] :
          llvm::zip_equal(stride, padding, inferredOutSize))
       expectedInputShape.emplace_back(ceilDiv(resSize, str) + pad * 2);
@@ -707,7 +708,7 @@ public:
       SmallVector<int64_t> low(outRank, 0);
       SmallVector<int64_t> high(NC, 0);
       for (auto &&[inpSize, outSize] : llvm::zip_equal(
-               inputSize, ArrayRef(expectedInputShape).take_back(3))) {
+               inputSize, ArrayRef(expectedInputShape).take_back(spatial))) {
         high.emplace_back(outSize - inpSize);
       }
 
@@ -1526,8 +1527,8 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality(
   patterns.add<ConvertAtenMaxPoolOp<AtenMaxPool3dWithIndicesOp>>(typeConverter,
                                                                  context);
 
-  target.addIllegalOp<AtenMaxUnpool3dOp>();
-  patterns.add<ConvertAtenMaxUnpool3dOp>(typeConverter, context);
+  target.addIllegalOp<AtenMaxUnpoolOp>();
+  patterns.add<ConvertAtenMaxUnpoolOp>(typeConverter, context);
 
   target.addIllegalOp<AtenAvgPool1dOp, AtenAvgPool2dOp, AtenAvgPool3dOp>();
   patterns
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 59cf69393..a03841398 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -8176,7 +8176,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0, %0 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
 "    return %1 : !torch.tuple<list<int>, list<int>>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool3d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.max_unpool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.list<int> {\n"
 "    %str = torch.constant.str \"AssertionError: Input and indices must be of the same rank\"\n"
 "    %str_0 = torch.constant.str \"AssertionError: output_size must have 3 elements\"\n"
 "    %none = torch.constant.none\n"
@@ -12092,7 +12092,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
 "    return %1 : !torch.tuple<int, int>\n"
 "  }\n"
-"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool3d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
+"  func.func @\"__torch_mlir_dtype_fn.aten.max_unpool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>) -> !torch.int {\n"
 "    %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
 "    return %0#1 : !torch.int\n"
 "  }\n"
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index bdb4d7f47..79b14c2ff 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -557,8 +557,8 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = {
     "ElementwiseRreluTrainStaticModule_basic",
     "MaxPool1dCeilModeTrueModule_basic",
     "MaxPool1dStaticCeilModeTrueModule_basic",
-    "MaxUnpool3dModulePad0_basic",
-    "MaxUnpool3dModule_basic",
+    "MaxUnpoolModulePad0_basic",
+    "MaxUnpoolModule_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -2837,8 +2837,8 @@ ONNX_XFAIL_SET = {
     "MaxPool3dWithIndicesNonDefaultDilationModule_basic",
     "MaxPool3dWithIndicesNonDefaultParamsModule_basic",
     "MaxPool3dWithIndicesNonDefaultStrideModule_basic",
-    "MaxUnpool3dModule_basic",
-    "MaxUnpool3dModulePad0_basic",
+    "MaxUnpoolModule_basic",
+    "MaxUnpoolModulePad0_basic",
     "MeanDimEmptyDimModule_basic",
     "Mlp1LayerModule_basic",
     "Mlp2LayerModuleNoBias_basic",
@@ -3224,8 +3224,8 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
     "FakeQuantizePerTensorAffineCachemaskModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
     "MaskedScatterStaticBasic_basic",
-    "MaxUnpool3dModulePad0_basic",
-    "MaxUnpool3dModule_basic",
+    "MaxUnpoolModulePad0_basic",
+    "MaxUnpoolModule_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
@@ -3892,8 +3892,8 @@ ONNX_TOSA_XFAIL_SET = {
     "FakeQuantizePerTensorAffineCachemaskModule_basic",
     "IndexPutWithNoneAndBroadcastModule_basic",
     "MaskedScatterStaticBasic_basic",
-    "MaxUnpool3dModulePad0_basic",
-    "MaxUnpool3dModule_basic",
+    "MaxUnpoolModulePad0_basic",
+    "MaxUnpoolModule_basic",
     "MultinomialModule2D_F32",
     "MultinomialModule2D_basic",
     "MultinomialModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index bc49757ee..9591d4884 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -1056,14 +1056,14 @@ def aten〇max_pool3d_with_indices〡shape(self: List[int], kernel_size: List[in
     maxpool3d = indices = _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode)
     return maxpool3d, indices
 
-def aten〇max_unpool3d〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
-    assert (len(self) == 5 or len(self) == 4), "Input be of rank 4 or 5"
-    assert (len(output_size) == 3), "output_size must have 3 elements"
+def aten〇max_unpool〡shape(self: List[int], indices: List[int], output_size: List[int], stride: List[int], padding: List[int]) -> List[int]:
+    assert (len(self) == 5 or len(self) == 4), "Input must be of rank 4 or 5"
+    assert (len(output_size) == len(self) - 2), "output_size must have one element per spatial dimension"
     assert (len(self) == len(indices)), "Input and indices must be of the same rank"
     if len(self) == 5:
         return [self[0], self[1], output_size[0], output_size[1], output_size[2]]
     else:
-        return [self[0], output_size[0], output_size[1], output_size[2]]
+        return [self[0], self[1], output_size[0], output_size[1]]
 
 def aten〇upsample_nearest2d_backward〡shape(grad_output: List[int], output_size: List[int], input_size: List[int], scales_h: Optional[float] = None, scales_w: Optional[float] = None) -> List[int]:
     return input_size
@@ -3179,7 +3179,7 @@ def aten〇max_pool3d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker
     self_rank, self_dtype = self_rank_dtype
     return self_dtype, torch.int64
 
-def aten〇max_unpool3d〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
+def aten〇max_unpool〡dtype(self_rank_dtype: Tuple[int, int], indices_rank_dtype: Tuple[int, int], output_size: List[int], stride: List[int], padding: List[int]) -> int:
     self_rank, self_dtype = self_rank_dtype
     return self_dtype
 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index 5f53e17b9..a58e37d0a 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -618,7 +618,6 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     )
     emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
     emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
-    emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)")
     emit(
         "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)",
         has_canonicalizer=True,
@@ -627,7 +626,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)"
     )
     emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)")
-    emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
+    emit("aten::max_unpool : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)")
     emit(
         "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)"
     )
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
index 4cef7056a..a3c6a67f3 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py
@@ -1972,7 +1972,7 @@ def AdaptiveMaxPool3dStaticWithIndices_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class MaxUnpool3dModule(torch.nn.Module):
+class MaxUnpoolModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -1985,11 +1985,11 @@ class MaxUnpool3dModule(torch.nn.Module):
         ]
     )
     def forward(self, x, indices):
-        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))
+        return torch.ops.aten.max_unpool(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 1))
 
 
-@register_test_case(module_factory=lambda: MaxUnpool3dModule())
-def MaxUnpool3dModule_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: MaxUnpoolModule())
+def MaxUnpoolModule_basic(module, tu: TestUtils):
     input = tu.rand(2, 2, 4, 5, 6)
     pool = torch.nn.MaxPool3d(
         kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 1), return_indices=True
@@ -2000,7 +2000,7 @@ def MaxUnpool3dModule_basic(module, tu: TestUtils):
 
 # We have a special case for all-zeros padding, test it too.
-class MaxUnpool3dModulePad0(torch.nn.Module):
+class MaxUnpoolModulePad0(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -2013,11 +2013,11 @@ class MaxUnpool3dModulePad0(torch.nn.Module):
         ]
     )
     def forward(self, x, indices):
-        return torch.ops.aten.max_unpool3d(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))
+        return torch.ops.aten.max_unpool(x, indices, (4, 5, 6), (2, 3, 2), (0, 0, 0))
 
 
-@register_test_case(module_factory=lambda: MaxUnpool3dModulePad0())
-def MaxUnpool3dModulePad0_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: MaxUnpoolModulePad0())
+def MaxUnpoolModulePad0_basic(module, tu: TestUtils):
     input = tu.rand(2, 2, 4, 5, 6)
     pool = torch.nn.MaxPool3d(
         kernel_size=(2, 2, 2), stride=(2, 3, 2), padding=(0, 0, 0), return_indices=True
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 21be2a65f..75ae70747 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1667,14 +1667,20 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc
 
 // -----
 
-// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape
-func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_maxunpool_2d_export_without_output_shape
+func.func @test_maxunpool_2d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
   // CHECK: %[[INT4:.*]] = torch.constant.int 4
   // CHECK: %[[INT4_0:.*]] = torch.constant.int 4
   // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[INT0:.*]] = torch.constant.int 0
+  // CHECK: %[[INT0_1:.*]] = torch.constant.int 0
+  // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[INT2:.*]] = torch.constant.int 2
+  // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
+  // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]] : (!torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4],f32>
@@ -1682,8 +1688,8 @@ func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1
 
 // -----
 
-// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape
-func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+// CHECK-LABEL: func.func @test_maxunpool_3d_export_without_output_shape
+func.func @test_maxunpool_3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: %[[INT1:.*]] = torch.constant.int 1
   // CHECK: %[[INT1_0:.*]] = torch.constant.int 1
   // CHECK: %[[INT4:.*]] = torch.constant.int 4
@@ -1698,7 +1704,7 @@ func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1
   // CHECK: %[[INT2_1:.*]] = torch.constant.int 2
   // CHECK: %[[INT2_2:.*]] = torch.constant.int 2
   // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
-  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
+  // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4,4],f32>
   // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32>
   %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32>
   return %0 : !torch.vtensor<[1,1,4,4,4],f32>
diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir
index 558c50c4f..1a53dfdc6 100644
--- a/test/Conversion/TorchToLinalg/pooling.mlir
+++ b/test/Conversion/TorchToLinalg/pooling.mlir
@@ -95,3 +95,43 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.
 // CHECK: } -> tensor<?x?x?x?x?xf32>
   return %4 : !torch.vtensor<[?,?,?,?,?],f32>
 }
+
+// -----
+
+// CHECK: #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 floordiv 2, d3 floordiv 2)>
+// CHECK: #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+func.func @forward_max_unpool2d(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %int1 = torch.constant.int 1
+  %int1_0 = torch.constant.int 1
+  %int4 = torch.constant.int 4
+  %int4_1 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int1, %int1_0, %int4, %int4_1 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %int0 = torch.constant.int 0
+  %int0_2 = torch.constant.int 0
+  %1 = torch.prim.ListConstruct %int0, %int0_2 : (!torch.int, !torch.int) -> !torch.list<int>
+  %int2 = torch.constant.int 2
+  %int2_3 = torch.constant.int 2
+  %2 = torch.prim.ListConstruct %int2, %int2_3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.max_unpool %arg0, %arg1, %0, %2, %1 : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list<int>, !torch.list<int>, !torch.list<int> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: %[[INDICES:.*]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[1,1,2,2],si64> -> tensor<1x1x2x2xi64>
+  // CHECK: %[[INPUT:.*]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[1,1,2,2],f32> -> tensor<1x1x2x2xf32>
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[DIM0:.*]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<1x1x2x2xf32>
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[DIM1:.*]] = tensor.dim %[[INPUT]], %[[C1]] : tensor<1x1x2x2xf32>
+  // CHECK: %[[SHAPE:.*]] = tensor.empty(%[[DIM0]], %[[DIM1]]) : tensor<?x?x4x4xf32>
+  // CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [#map, #map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[INPUT]], %[[INDICES]] : tensor<1x1x2x2xf32>, tensor<1x1x2x2xi64>) outs(%[[SHAPE]] : tensor<?x?x4x4xf32>) {
+  // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[CURRENT_INDEX:.*]]: i64, %[[OUT:.*]]: f32):
+  // CHECK-NEXT: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK-NEXT: %[[INDEX_CAST:.*]] = arith.index_cast %[[CURRENT_INDEX]] : i64 to index
+  // CHECK-NEXT: %[[INDEX2:.*]] = linalg.index 2 : index
+  // CHECK-NEXT: %[[INDEX3:.*]] = linalg.index 3 : index
+  // CHECK-NEXT: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK-NEXT: %[[MULI:.*]] = arith.muli %[[INDEX2]], %[[C4]] : index
+  // CHECK-NEXT: %[[ADDI:.*]] = arith.addi %[[MULI]], %[[INDEX3]] : index
+  // CHECK-NEXT: %[[CMPI:.*]] = arith.cmpi eq, %[[INDEX_CAST]], %[[ADDI]] : index
+  // CHECK-NEXT: %[[SELECT:.*]] = arith.select %[[CMPI]], %[[CURRENT_VALUE]], %[[CST]] : f32
+  // CHECK-NEXT: linalg.yield %[[SELECT]] : f32
+  // CHECK: } -> tensor<?x?x4x4xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast %[[GENERIC]] : tensor<?x?x4x4xf32> to tensor<1x1x4x4xf32>
+  // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CAST]] : tensor<1x1x4x4xf32> -> !torch.vtensor<[1,1,4,4],f32>
+  // CHECK: return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32>
+  return %3 : !torch.vtensor<[1,1,4,4],f32>
+}
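
Notes on the generalized lowering, with two illustrative Python sketches. The helper names, shapes, and numbers below are chosen to match the tests above and are not part of the patch itself.

First, the element-wise rule the linalg.generic in test/Conversion/TorchToLinalg/pooling.mlir encodes: every output coordinate reads the pooled value selected by the floordiv-by-stride affine map, and the muli/addi/cmpi/select body keeps that value only where the flattened output coordinate equals the index recorded by max pooling. A minimal NumPy mirror of that rule for the 2-D, stride-2, zero-padding case (assuming non-overlapping windows), checked against torch.nn.functional.max_unpool2d:

import numpy as np
import torch

def max_unpool2d_reference(x, indices, out_hw):
    # x, indices: (N, C, H_in, W_in); out_hw: spatial shape of the result.
    n, c, h_in, w_in = x.shape
    h_out, w_out = out_hw
    sh, sw = h_out // h_in, w_out // w_in  # strides implied by the shapes
    out = np.zeros((n, c, h_out, w_out), dtype=x.dtype)
    for ni in range(n):
        for ci in range(c):
            for oy in range(h_out):
                for ox in range(w_out):
                    # The affine map reads the pooled value at
                    # (oy floordiv sh, ox floordiv sw) ...
                    val = x[ni, ci, oy // sh, ox // sw]
                    idx = indices[ni, ci, oy // sh, ox // sw]
                    # ... and it is kept only where the flattened output
                    # coordinate matches the recorded max index.
                    out[ni, ci, oy, ox] = val if idx == oy * w_out + ox else 0.0
    return out

pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
t = torch.rand(1, 1, 4, 4)
pooled, idx = pool(t)
got = max_unpool2d_reference(pooled.numpy(), idx.numpy(), (4, 4))
want = torch.nn.functional.max_unpool2d(pooled, idx, kernel_size=2, stride=2)
assert np.allclose(got, want.numpy())

The take_back(spatial) on the constant output list is what lets one pattern serve both callers: torch-level code passes spatial-only sizes, while the ONNX path passes the full result shape.

Second, the static-output-shape requirement in Pooling.cpp exists because aten.max_unpool does not carry the kernel size, so the unpooled extent cannot be recovered from strides alone. The numbers from the comment above ConvertAtenMaxUnpoolOp demonstrate the ambiguity:

import torch

# A length-5 input pooled with stride 2 yields 2 values for kernel size 2
# and for kernel size 3, so input_size=2 with stride=2 maps back to either
# output_size=4 or output_size=5.
for k in (2, 3):
    pool = torch.nn.MaxPool1d(kernel_size=k, stride=2, return_indices=True)
    y, _ = pool(torch.rand(1, 1, 5))
    print(k, y.shape[-1])  # prints "2 2" then "3 2"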