From 5437f32193888f4cae3b4ae02123bd8828564ad6 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:52:15 -0800 Subject: [PATCH] [onnx][torch] Lower `onnx.grid_sampler` to the `torch` equivalents (#2952) This is the lowering of gridsampler from onnx to torch using our prior implementation of AtenGridSamplerOp. Here are several checks for cornercases implemented. We may decide to have part of these checks in AtenGridSamplerOp instead of the onnx lowering portion. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 67 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 12 ++++ 2 files changed, 79 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 12b7ab559..4d1aaf42d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -92,6 +92,73 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( operand, vApproximate); return success(); }); + patterns.onOp( + "GridSample", 20, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + Value grid; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(grid, 1) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "operand grid_sampler bind failure"); + + auto inputTensorType = input.getType().cast(); + ArrayRef inputShape = inputTensorType.getSizes(); + uint32_t inputRank = inputShape.size(); + auto gridTensorType = grid.getType().cast(); + ArrayRef gridShape = gridTensorType.getSizes(); + uint32_t gridRank = gridShape.size(); + + if (inputRank != 4) + return rewriter.notifyMatchFailure(binder.op, + "only input rank 4 supported"); + if (gridRank != 4) + return rewriter.notifyMatchFailure(binder.op, + "only grid rank 4 supported"); + if (inputShape[0] != gridShape[0]) + return rewriter.notifyMatchFailure( + binder.op, "N must be same for input and grid"); + if (gridShape[3] != 2) + return rewriter.notifyMatchFailure(binder.op, + "gridShape[3] expected to be 2"); + std::string mode; + if (binder.customOpNameStringAttr(mode, "mode", "bilinear")) + return rewriter.notifyMatchFailure(binder.op, "mode bind failure"); + if (mode != "bilinear") + return rewriter.notifyMatchFailure( + binder.op, "currently only mode : bilinear supported"); + std::string padding; + if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) + return rewriter.notifyMatchFailure(binder.op, + "padding_mode bind failure"); + if (padding != "zeros") + return rewriter.notifyMatchFailure( + binder.op, "currently only padding_mode : zeros supported"); + int64_t align; + if (binder.s64IntegerAttr(align, "align_corners", 0)) + return rewriter.notifyMatchFailure(binder.op, + "align_corners bind failure"); + if (align != 0) + return rewriter.notifyMatchFailure( + binder.op, "currently only align_corners : 0 supported"); + + Value interpolationMode = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value paddingMode = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value alignCorners = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(false)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, grid, interpolationMode, paddingMode, + alignCorners); + return success(); + }); patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 8729e7f2d..d5a47aba3 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -395,6 +395,18 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso // ----- +// CHECK-LABEL: @test_grid_sampler +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 +// CHECK: %[[B0:.*]] = torch.constant.bool false +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_less_or_equal func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>