[onnx][torch][linalg] Implementing align-corner modes for gridsampler (#3171)

Align corner modes which select what the corners mean. 
Either the center of the corner points or the edges of the edge points.

---------

Co-authored-by: Rob Suderman <rob.suderman@gmail.com>
pull/3183/head
Andreas Falkenberg 2024-04-17 13:38:19 -07:00 committed by GitHub
parent 3aa81f78d8
commit b66eabd492
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 136 additions and 30 deletions

View File

@ -140,9 +140,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
if (binder.s64IntegerAttr(align, "align_corners", 0)) if (binder.s64IntegerAttr(align, "align_corners", 0))
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,
"align_corners bind failure"); "align_corners bind failure");
if (align != 1)
return rewriter.notifyMatchFailure(
binder.op, "currently only align_corners = 1 supported");
Value interpolationMode = rewriter.create<Torch::ConstantIntOp>( Value interpolationMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), binder.getLoc(), rewriter.getType<Torch::IntType>(),
@ -150,9 +147,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value paddingMode = rewriter.create<Torch::ConstantIntOp>( Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
bool alignMode = align;
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>( Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(), binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false)); rewriter.getBoolAttr(alignMode));
rewriter.replaceOpWithNewOp<Torch::AtenGridSamplerOp>( rewriter.replaceOpWithNewOp<Torch::AtenGridSamplerOp>(
binder.op, resultType, input, grid, interpolationMode, paddingMode, binder.op, resultType, input, grid, interpolationMode, paddingMode,

View File

@ -2565,6 +2565,7 @@ public:
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1)); resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 1));
if (resultType.isDynamicDim(3)) if (resultType.isDynamicDim(3))
resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2)); resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
Value alignCorners = adaptor.getAlignCorners();
Value resultFinal = Value resultFinal =
rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize); rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
auto sGrid = rewriter.create<linalg::GenericOp>( auto sGrid = rewriter.create<linalg::GenericOp>(
@ -2573,30 +2574,56 @@ public:
[&](OpBuilder &b, Location loc, ValueRange args) { [&](OpBuilder &b, Location loc, ValueRange args) {
Value gr0 = args[1]; Value gr0 = args[1];
Value gr1 = args[0]; Value gr1 = args[0];
Value gr0Half = b.create<arith::DivFOp>(loc, gr0, twoFloat);
Value gr1Half = b.create<arith::DivFOp>(loc, gr1, twoFloat);
Value gr0HalfSelect =
b.create<arith::SelectOp>(loc, alignCorners, zeroFloat, gr0Half);
Value gr1HalfSelect =
b.create<arith::SelectOp>(loc, alignCorners, zeroFloat, gr1Half);
Value gplus0 = b.create<arith::AddFOp>(loc, gr0, oneFloat); Value gplus0 = b.create<arith::AddFOp>(loc, gr0, oneFloat);
Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat); Value gplus1 = b.create<arith::AddFOp>(loc, gr1, oneFloat);
Value result0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e); Value gPlusMul0 = b.create<arith::MulFOp>(loc, gplus0, innerDim0e);
Value result1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e); Value gPlusMul1 = b.create<arith::MulFOp>(loc, gplus1, innerDim1e);
Value lower0 = b.create<arith::FPToSIOp>(loc, int64type, result0); Value result0 =
Value lower1 = b.create<arith::FPToSIOp>(loc, int64type, result1); b.create<arith::AddFOp>(loc, gPlusMul0, gr0HalfSelect);
Value result1 =
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
Value checkLowerBound0 = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
Value checkLowerBound1 = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, result1, zeroFloat);
Value lowerOrig0 = b.create<arith::FPToSIOp>(loc, int64type, result0);
Value lowerOrig1 = b.create<arith::FPToSIOp>(loc, int64type, result1);
Value zeroInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
Value oneInt = Value oneInt =
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 1)); b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 1));
Value lowerSub0 = b.create<arith::SubIOp>(loc, lowerOrig0, oneInt);
Value lowerSub1 = b.create<arith::SubIOp>(loc, lowerOrig1, oneInt);
Value lower0 = b.create<arith::SelectOp>(loc, checkLowerBound0,
lowerSub0, lowerOrig0);
Value lower1 = b.create<arith::SelectOp>(loc, checkLowerBound1,
lowerSub1, lowerOrig1);
Value lowerValid0 =
b.create<arith::SelectOp>(loc, checkLowerBound0, zeroInt, lower0);
Value lowerValid1 =
b.create<arith::SelectOp>(loc, checkLowerBound1, zeroInt, lower1);
Value upper0 = Value upper0 =
b.create<arith::AddIOp>(loc, int64type, lower0, oneInt); b.create<arith::AddIOp>(loc, int64type, lower0, oneInt);
Value upper1 = Value upper1 =
b.create<arith::AddIOp>(loc, int64type, lower1, oneInt); b.create<arith::AddIOp>(loc, int64type, lower1, oneInt);
Value notValid0 = rewriter.create<arith::CmpIOp>( Value notValidUpper0 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); loc, arith::CmpIPredicate::sgt, upper0, innerDim0c);
Value notValid1 = rewriter.create<arith::CmpIOp>( Value notValidUpper1 = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); loc, arith::CmpIPredicate::sgt, upper1, innerDim1c);
Value upperValid0 = Value upperValid0 =
b.create<arith::SelectOp>(loc, notValid0, lower0, upper0); b.create<arith::SelectOp>(loc, notValidUpper0, lower0, upper0);
Value upperValid1 = Value upperValid1 =
b.create<arith::SelectOp>(loc, notValid1, lower1, upper1); b.create<arith::SelectOp>(loc, notValidUpper1, lower1, upper1);
Value lw0 = Value lw0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), lower0); b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowerValid0);
Value lw1 = Value lw1 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), lower1); b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowerValid1);
Value up0 = Value up0 =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid0); b.create<arith::IndexCastOp>(loc, b.getIndexType(), upperValid0);
Value up1 = Value up1 =
@ -2604,23 +2631,31 @@ public:
Value N = b.create<linalg::IndexOp>(loc, 0); Value N = b.create<linalg::IndexOp>(loc, 0);
Value C = b.create<linalg::IndexOp>(loc, 1); Value C = b.create<linalg::IndexOp>(loc, 1);
Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1); Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1);
Value result00a = b.create<arith::SelectOp>(loc, checkLowerBound0,
zeroFloat, result00);
Value result00b = b.create<arith::SelectOp>(loc, checkLowerBound1,
zeroFloat, result00a);
Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1);
Value result01a = Value result01a = b.create<arith::SelectOp>(loc, notValidUpper1,
b.create<arith::SelectOp>(loc, notValid1, zeroFloat, result01); zeroFloat, result01);
Value result01b = b.create<arith::SelectOp>(loc, checkLowerBound0,
zeroFloat, result01a);
Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1);
Value result10a = Value result10a = b.create<arith::SelectOp>(loc, notValidUpper0,
b.create<arith::SelectOp>(loc, notValid0, zeroFloat, result10); zeroFloat, result10);
Value result10b = b.create<arith::SelectOp>(loc, checkLowerBound1,
zeroFloat, result10a);
Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1);
Value result11a = Value result11a = b.create<arith::SelectOp>(loc, notValidUpper0,
b.create<arith::SelectOp>(loc, notValid0, zeroFloat, result11); zeroFloat, result11);
Value result11b = Value result11b = b.create<arith::SelectOp>(loc, notValidUpper1,
b.create<arith::SelectOp>(loc, notValid1, zeroFloat, result11a); zeroFloat, result11a);
Value lw0a = b.create<arith::SIToFPOp>(loc, floatType, lower0); Value lw0a = b.create<arith::SIToFPOp>(loc, floatType, lower0);
Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1); Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1);
Value d1 = b.create<arith::SubFOp>(loc, result0, lw0a); Value d1 = b.create<arith::SubFOp>(loc, result0, lw0a);
Value d0 = b.create<arith::SubFOp>(loc, result1, lw1a); Value d0 = b.create<arith::SubFOp>(loc, result1, lw1a);
Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d0); Value resultScaled0 = lambdaInter(b, loc, result00b, result01b, d0);
Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d0); Value resultScaled1 = lambdaInter(b, loc, result10b, result11b, d0);
Value resultScaled = Value resultScaled =
lambdaInter(b, loc, resultScaled0, resultScaled1, d1); lambdaInter(b, loc, resultScaled0, resultScaled1, d1);
b.create<linalg::YieldOp>(loc, resultScaled); b.create<linalg::YieldOp>(loc, resultScaled);

View File

@ -35,7 +35,7 @@ class GridSamplerBasic1(torch.nn.Module):
def GridSamplerBasic1_basic( def GridSamplerBasic1_basic(
module, tu: TestUtils): module, tu: TestUtils):
inp = torch.rand(7,8,12,4) inp = torch.rand(7,8,12,4)
grd = torch.rand(7,11,13,2)*2-1 grd = torch.rand(7,11,13,2)*2.0-1.0
module.forward(inp, grd) module.forward(inp, grd)
@ -69,3 +69,34 @@ def GridSamplerBasic2_basic(
grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor)
module.forward(inp, grd) module.forward(inp, grd)
class GridSamplerBasic3(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([1, 1, 4, 4], torch.float32, True),
([1, 1, 3, 2], torch.float32, True)
])
def forward(self, x, g):
interpolation_mode=0,
padding_mode=0,
align_corners=False,
tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0],
padding_mode[0], align_corners[0])
return tRes
@register_test_case(
module_factory=lambda: GridSamplerBasic3())
def GridSamplerBasic3_basic(
module, tu: TestUtils):
inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320],
[0.3074, 0.6341, 0.4901, 0.8964],
[0.4556, 0.6323, 0.3489, 0.4017],
[0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor)
grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor)
module.forward(inp, grd)

View File

@ -539,14 +539,26 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
// ----- // -----
// CHECK-LABEL: @test_grid_sampler // CHECK-LABEL: @test_grid_sampler01
// CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[B0:.*]] = torch.constant.bool false // CHECK: %[[B0:.*]] = torch.constant.bool false
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> // CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_grid_sampler01(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32> return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: @test_grid_sampler02
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[B0:.*]] = torch.constant.bool true
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
} }
// ----- // -----

View File

@ -58,3 +58,32 @@ func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte
%4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32> return %4 : !torch.vtensor<[?,?,?,?],f32>
} }
// -----
// CHECK-LABEL: func @grid_sampler3
// CHECK: #map
// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32
// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32
// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32
// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32
// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32
// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32
// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32
// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32
// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32
// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32
// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32
// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32
// CHECK-DAG: linalg.yield %[[X50]] : f32
// CHECK: } -> tensor<?x?x?x?xf32>
// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32>
func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
%false = torch.constant.bool 1
%int0 = torch.constant.int 0
%int1 = torch.constant.int 0
%4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}