From b66eabd492afb0d5cd1f778d82a8e5d91cc210a9 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Wed, 17 Apr 2024 13:38:19 -0700 Subject: [PATCH] [onnx][torch][linalg] Implementing align-corner modes for gridsampler (#3171) Implement both align_corners modes, which select what the corner grid coordinates mean: either the centers of the corner pixels (align_corners = 1) or the outer edges of the corner pixels (align_corners = 0). --------- Co-authored-by: Rob Suderman --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 7 +- .../TorchToLinalg/Uncategorized.cpp | 75 ++++++++++++++----- .../test_suite/gridsampler.py | 33 +++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 20 ++++- .../Conversion/TorchToLinalg/gridsampler.mlir | 31 +++++++- 5 files changed, 136 insertions(+), 30 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 8d6b38814..92ef81390 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -140,9 +140,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (binder.s64IntegerAttr(align, "align_corners", 0)) return rewriter.notifyMatchFailure(binder.op, "align_corners bind failure"); - if (align != 1) - return rewriter.notifyMatchFailure( - binder.op, "currently only align_corners = 1 supported"); Value interpolationMode = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -150,9 +147,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value paddingMode = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + bool alignMode = align; Value alignCorners = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getBoolAttr(false)); + rewriter.getBoolAttr(alignMode)); rewriter.replaceOpWithNewOp( binder.op, resultType, input, grid, interpolationMode, paddingMode, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp 
b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index bf25786c3..441c76ce7 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2565,6 +2565,7 @@ public: resultSize.push_back(rewriter.create(loc, grid, 1)); if (resultType.isDynamicDim(3)) resultSize.push_back(rewriter.create(loc, grid, 2)); + Value alignCorners = adaptor.getAlignCorners(); Value resultFinal = rewriter.create(loc, resultType, resultSize); auto sGrid = rewriter.create( @@ -2573,30 +2574,56 @@ public: [&](OpBuilder &b, Location loc, ValueRange args) { Value gr0 = args[1]; Value gr1 = args[0]; + Value gr0Half = b.create(loc, gr0, twoFloat); + Value gr1Half = b.create(loc, gr1, twoFloat); + Value gr0HalfSelect = + b.create(loc, alignCorners, zeroFloat, gr0Half); + Value gr1HalfSelect = + b.create(loc, alignCorners, zeroFloat, gr1Half); Value gplus0 = b.create(loc, gr0, oneFloat); Value gplus1 = b.create(loc, gr1, oneFloat); - Value result0 = b.create(loc, gplus0, innerDim0e); - Value result1 = b.create(loc, gplus1, innerDim1e); - Value lower0 = b.create(loc, int64type, result0); - Value lower1 = b.create(loc, int64type, result1); + Value gPlusMul0 = b.create(loc, gplus0, innerDim0e); + Value gPlusMul1 = b.create(loc, gplus1, innerDim1e); + Value result0 = + b.create(loc, gPlusMul0, gr0HalfSelect); + Value result1 = + b.create(loc, gPlusMul1, gr1HalfSelect); + Value checkLowerBound0 = b.create( + loc, arith::CmpFPredicate::OLT, result0, zeroFloat); + Value checkLowerBound1 = b.create( + loc, arith::CmpFPredicate::OLT, result1, zeroFloat); + Value lowerOrig0 = b.create(loc, int64type, result0); + Value lowerOrig1 = b.create(loc, int64type, result1); + Value zeroInt = + b.create(loc, b.getIntegerAttr(int64type, 0)); Value oneInt = b.create(loc, b.getIntegerAttr(int64type, 1)); + Value lowerSub0 = b.create(loc, lowerOrig0, oneInt); + Value lowerSub1 = b.create(loc, lowerOrig1, oneInt); + Value lower0 = b.create(loc, checkLowerBound0, + 
lowerSub0, lowerOrig0); + Value lower1 = b.create(loc, checkLowerBound1, + lowerSub1, lowerOrig1); + Value lowerValid0 = + b.create(loc, checkLowerBound0, zeroInt, lower0); + Value lowerValid1 = + b.create(loc, checkLowerBound1, zeroInt, lower1); Value upper0 = b.create(loc, int64type, lower0, oneInt); Value upper1 = b.create(loc, int64type, lower1, oneInt); - Value notValid0 = rewriter.create( + Value notValidUpper0 = rewriter.create( loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); - Value notValid1 = rewriter.create( + Value notValidUpper1 = rewriter.create( loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); Value upperValid0 = - b.create(loc, notValid0, lower0, upper0); + b.create(loc, notValidUpper0, lower0, upper0); Value upperValid1 = - b.create(loc, notValid1, lower1, upper1); + b.create(loc, notValidUpper1, lower1, upper1); Value lw0 = - b.create(loc, b.getIndexType(), lower0); + b.create(loc, b.getIndexType(), lowerValid0); Value lw1 = - b.create(loc, b.getIndexType(), lower1); + b.create(loc, b.getIndexType(), lowerValid1); Value up0 = b.create(loc, b.getIndexType(), upperValid0); Value up1 = @@ -2604,23 +2631,31 @@ public: Value N = b.create(loc, 0); Value C = b.create(loc, 1); Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1); + Value result00a = b.create(loc, checkLowerBound0, + zeroFloat, result00); + Value result00b = b.create(loc, checkLowerBound1, + zeroFloat, result00a); Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); - Value result01a = - b.create(loc, notValid1, zeroFloat, result01); + Value result01a = b.create(loc, notValidUpper1, + zeroFloat, result01); + Value result01b = b.create(loc, checkLowerBound0, + zeroFloat, result01a); Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); - Value result10a = - b.create(loc, notValid0, zeroFloat, result10); + Value result10a = b.create(loc, notValidUpper0, + zeroFloat, result10); + Value result10b = b.create(loc, checkLowerBound1, + zeroFloat, result10a); 
Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); - Value result11a = - b.create(loc, notValid0, zeroFloat, result11); - Value result11b = - b.create(loc, notValid1, zeroFloat, result11a); + Value result11a = b.create(loc, notValidUpper0, + zeroFloat, result11); + Value result11b = b.create(loc, notValidUpper1, + zeroFloat, result11a); Value lw0a = b.create(loc, floatType, lower0); Value lw1a = b.create(loc, floatType, lower1); Value d1 = b.create(loc, result0, lw0a); Value d0 = b.create(loc, result1, lw1a); - Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d0); - Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d0); + Value resultScaled0 = lambdaInter(b, loc, result00b, result01b, d0); + Value resultScaled1 = lambdaInter(b, loc, result10b, result11b, d0); Value resultScaled = lambdaInter(b, loc, resultScaled0, resultScaled1, d1); b.create(loc, resultScaled); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py index 2960041bd..815f64eed 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py @@ -35,7 +35,7 @@ class GridSamplerBasic1(torch.nn.Module): def GridSamplerBasic1_basic( module, tu: TestUtils): inp = torch.rand(7,8,12,4) - grd = torch.rand(7,11,13,2)*2-1 + grd = torch.rand(7,11,13,2)*2.0-1.0 module.forward(inp, grd) @@ -69,3 +69,34 @@ def GridSamplerBasic2_basic( grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) module.forward(inp, grd) + +class GridSamplerBasic3(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 4, 4], torch.float32, True), + ([1, 1, 3, 2], torch.float32, True) + ]) + def forward(self, x, g): + interpolation_mode=0, + padding_mode=0, + align_corners=False, + tRes = torch.ops.aten.grid_sampler(x, 
g, interpolation_mode[0], + padding_mode[0], align_corners[0]) + return tRes + +@register_test_case( + module_factory=lambda: GridSamplerBasic3()) +def GridSamplerBasic3_basic( + module, tu: TestUtils): + inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor) + grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) + module.forward(inp, grd) + diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 4aa3716de..76b7e11c2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -539,14 +539,26 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso // ----- -// CHECK-LABEL: @test_grid_sampler +// CHECK-LABEL: @test_grid_sampler01 // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 // CHECK: %[[B0:.*]] = torch.constant.bool false // CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> -func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> - return %4 : !torch.vtensor<[?,?,?,?],f32> +func.func @test_grid_sampler01(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> 
!torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: @test_grid_sampler02 +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 +// CHECK: %[[B0:.*]] = torch.constant.bool true +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %0 : !torch.vtensor<[?,?,?,?],f32> } // ----- diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir index d392860fa..40a2dae45 100644 --- a/test/Conversion/TorchToLinalg/gridsampler.mlir +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -57,4 +57,33 @@ func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte %int1 = torch.constant.int 0 %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> 
return %4 : !torch.vtensor<[?,?,?,?],f32> -} \ No newline at end of file +} + +// ----- + +// CHECK-LABEL: func @grid_sampler3 +// CHECK: #map +// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32 +// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32 +// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32 +// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32 +// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32 +// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32 +// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32 +// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32 +// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32 +// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32 +// CHECK-DAG: linalg.yield %[[X50]] : f32 +// CHECK: } -> tensor +// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32> +func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %false = torch.constant.bool false + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} +