From 35dd8c52cd23d74cc495ccf314b1101d38cd6512 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 4 Jun 2024 21:09:53 +0530 Subject: [PATCH] [ONNX] Add OnnxToTorch Lowering for MaxUnpool op (#3413) This commit also adds the Torch declaration for aten.max_unpool2d and aten.max_unpool3d op. The TorchToLinalg lowering for the same will be added in a follow-up commit. Signed-Off By: Vivek Khandelwal --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 52 +++++++++++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 78 +++++++++++++++++++ .../build_tools/torch_ods_gen.py | 2 + .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 39 ++++++++++ 4 files changed, 171 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c0cac1f1f..559122f98 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6819,6 +6819,31 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ }]; } +def Torch_AtenMaxUnpool2dOp : Torch_Op<"aten.max_unpool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$indices, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxUnpool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenMaxUnpool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, @@ -6907,6 +6932,33 @@ def Torch_AtenMaxPool3dOp : Torch_Op<"aten.max_pool3d", [ }]; } +def Torch_AtenMaxUnpool3dOp : Torch_Op<"aten.max_unpool3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$indices, + AnyTorchListOfTorchIntType:$output_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMaxUnpool3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenMaxUnpool3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenMaxPool3dWithIndicesOp : Torch_Op<"aten.max_pool3d_with_indices", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 9f5b704a1..c7e41a7a0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1926,4 +1926,82 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); + patterns.onOp( + "MaxUnpool", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // TODO: Add support for `output_shape` arg. + if (binder.op->getNumOperands() == 3) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: output_shape arg is not supported"); + + Torch::ValueTensorType resultType; + Value data, indices; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "data/indices/resultType bind failure"); + std::optional maybeRank = Torch::getTensorRank(data); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + int64_t rank = *maybeRank; + int64_t spatial = rank - 2; + + if (rank <= 3 || rank > 5) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: MaxUnpool support " + "only present for rank 4/5 input"); + + if (!(resultType.hasSizes() && resultType.areAllSizesKnown())) + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: expected result to have all shapes " + "statically known"); + + SmallVector resultShape(resultType.getSizes()); + Value resultShapeList = + createConstantIntList(binder, rewriter, resultShape); + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, indices, resultShapeList); + return success(); + } + + SmallVector padding, strides; + if (binder.s64IntegerArrayAttr(padding, "pads", {})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (!padding.empty() && + padding.size() != static_cast(2 * spatial)) + return rewriter.notifyMatchFailure( + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); + if (binder.s64IntegerArrayAttr(strides, "strides", {})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (!strides.empty() && strides.size() != static_cast(spatial)) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + + if (padding.empty()) + padding.resize(spatial, 0); + if (strides.empty()) + strides.resize(spatial, 1); + + // If the padding is symmetric we can push the padding + // operation to the torch operator. + if (padding.size() == static_cast(2 * spatial)) { + bool equal = true; + for (int i = 0; i < spatial; ++i) { + equal = equal && (padding[i] == padding[i + spatial]); + } + if (equal) + padding.resize(spatial); + } + + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, indices, resultShapeList, stridesList, + paddingList); + return success(); + }); } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5cce514d4..7734f7ad2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -597,6 +597,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): ) emit("aten::max_pool1d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") emit("aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit("aten::max_unpool2d : (Tensor, Tensor, int[]) -> (Tensor)") emit( "aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)", has_canonicalizer=True, @@ -605,6 +606,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): "aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)" ) emit("aten::max_pool3d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)") + emit("aten::max_unpool3d : (Tensor, Tensor, int[], int[], int[]) -> (Tensor)") emit( "aten::max_pool3d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)" ) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 865648c40..227eac7d9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1087,3 +1087,42 @@ func.func @test_lpnormalization(%arg0: !torch.vtensor<[3,4,5,6,7],f32>) -> !torc %0 = torch.operator "onnx.LpNormalization"(%arg0) {torch.onnx.axis = 2 : si64, torch.onnx.p = 2 : si64} : (!torch.vtensor<[3,4,5,6,7],f32>) -> !torch.vtensor<[3,4,1,6,7],f32> return %0 : !torch.vtensor<[3,4,1,6,7],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_maxunpool_export_without_output_shape +func.func @test_maxunpool_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 + // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool2d %arg0, %arg1, %[[OUTPUT_SHAPE]] : !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>, !torch.list -> !torch.vtensor<[1,1,4,4],f32> + // return %[[RESULT]] : !torch.vtensor<[1,1,4,4],f32> + %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,1,2,2],si64>) -> !torch.vtensor<[1,1,4,4],f32> + return %0 : !torch.vtensor<[1,1,4,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxunpool3d_export_without_output_shape +func.func @test_maxunpool3d_export_without_output_shape(%arg0: !torch.vtensor<[1,1,2,2,2],f32>, %arg1: !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_0:.*]] = torch.constant.int 4 + // CHECK: %[[INT4_1:.*]] = torch.constant.int 4 + // CHECK: %[[OUTPUT_SHAPE:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_0]], %[[INT4]], %[[INT4_0]], %[[INT4_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]], %[[INT0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_2:.*]] = torch.constant.int 2 + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_1]], %[[INT2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.max_unpool3d %arg0, %arg1, %[[OUTPUT_SHAPE]], %[[STRIDE]], %[[PADDING]] : !torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>, !torch.list, !torch.list, !torch.list -> !torch.vtensor<[1,1,4,4,4],f32> + // return %[[RESULT]] : !torch.vtensor<[1,1,4,4,4],f32> + %0 = torch.operator "onnx.MaxUnpool"(%arg0, %arg1) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,2,2,2],f32>, !torch.vtensor<[1,1,2,2,2],si64>) -> !torch.vtensor<[1,1,4,4,4],f32> + return %0 : !torch.vtensor<[1,1,4,4,4],f32> +}