mirror of https://github.com/llvm/torch-mlir

[onnx] Gridsampler addition of nearest mode (#3320)

Added nearest-neighbor selection for onnx.GridSample.

branch: pull/3328/head
parent: 4b24909427
commit: adafd51823
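This change teaches both the onnx.GridSample import and the linalg lowering to honor mode = "nearest" alongside the existing linear path. As an eager-PyTorch illustration of the two behaviors (a minimal sketch using torch.nn.functional.grid_sample; the tensors are made up for illustration, not taken from the patch):

import torch
import torch.nn.functional as F

inp = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)
# One batch of three sample points, in normalized [-1, 1] coordinates.
grid = torch.tensor([[[[-0.35, -0.82], [-0.21, 0.21], [-0.65, -0.05]]]])

bilinear = F.grid_sample(inp, grid, mode="bilinear", align_corners=False)
nearest = F.grid_sample(inp, grid, mode="nearest", align_corners=False)
# "nearest" snaps each sample to the single closest input pixel instead of
# blending the four surrounding pixels.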
@@ -123,12 +123,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
       if (gridShape[3] != 2)
         return rewriter.notifyMatchFailure(binder.op,
                                            "gridShape[3] expected to be 2");
-      std::string mode;
-      if (binder.customOpNameStringAttr(mode, "mode", "linear"))
+      std::string iModeString;
+      int64_t iModeInt;
+      if (binder.customOpNameStringAttr(iModeString, "mode", "linear"))
         return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
-      if (mode != "linear" && mode != "bilinear")
+
+      if (iModeString == "linear" || iModeString == "bilinear") {
+        iModeInt = 0;
+      } else if (iModeString == "nearest") {
+        iModeInt = 1;
+      } else {
         return rewriter.notifyMatchFailure(
-            binder.op, "currently only mode : linear supported");
+            binder.op, "currently only mode : linear and nearest supported");
+      }

       std::string padding;
       if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
         return rewriter.notifyMatchFailure(binder.op,
@@ -143,7 +151,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(

       Value interpolationMode = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
-          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
+          rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));

       Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
           binder.getLoc(), rewriter.getType<Torch::IntType>(),
           rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
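The importer now maps the ONNX mode string onto the integer convention of torch.aten.grid_sampler (0 selects (bi)linear, 1 selects nearest, matching the iModeInt values above). A hedged Python restatement of the mapping:

def onnx_mode_to_interpolation_mode(mode: str = "linear") -> int:
    # Mirrors the binder logic above; "bilinear" is accepted alongside
    # "linear" because older GridSample opsets used that spelling.
    if mode in ("linear", "bilinear"):
        return 0
    if mode == "nearest":
        return 1
    raise ValueError("currently only mode : linear and nearest supported")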
@@ -2524,14 +2524,38 @@ public:
         Value result = b.create<tensor::ExtractOp>(loc, input, index);
         return result;
       };
-      auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y,
-                             Value d) -> Value {
+
+      auto lambdaLinear = [&](OpBuilder &b, Location loc, Value x, Value y,
+                              Value d) -> Value {
         Value dm = b.create<arith::SubFOp>(loc, oneFloat, d);
         Value ra = b.create<arith::MulFOp>(loc, x, dm);
         Value rb = b.create<arith::MulFOp>(loc, y, d);
         Value res = b.create<arith::AddFOp>(loc, ra, rb);
         return res;
       };
+
+      auto lambdaNearest = [&](OpBuilder &b, Location loc, Value x, Value y,
+                               Value d) -> Value {
+        Value halfConst = rewriter.create<arith::ConstantOp>(
+            loc, rewriter.getFloatAttr(floatType, 0.5));
+        Value checkClosest =
+            b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OLT, d, halfConst);
+        Value res = b.create<arith::SelectOp>(loc, checkClosest, x, y);
+        return res;
+      };
+
+      auto lambdaInterpolate = [&](OpBuilder &b, Location loc, Value iMode,
+                                   Value x, Value y, Value d) -> Value {
+        Value linear = lambdaLinear(b, loc, x, y, d);
+        Value nearest = lambdaNearest(b, loc, x, y, d);
+        Value zeroInt =
+            b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
+        Value checkMode = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
+                                                  iMode, zeroInt);
+        Value res = b.create<arith::SelectOp>(loc, checkMode, linear, nearest);
+        return res;
+      };

       auto resultType = getTypeConverter()
                             ->convertType(op.getResult().getType())
                             .cast<RankedTensorType>();
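Read as scalar math, the three lambdas are simple; a Python model of what the emitted arith ops compute per element (a sketch, not the MLIR itself):

def interpolate(i_mode: int, x: float, y: float, d: float) -> float:
    # lambdaLinear: blend the two neighbors by the fractional offset d.
    linear = x * (1.0 - d) + y * d
    # lambdaNearest: pick whichever neighbor is closer (d < 0.5 -> x).
    nearest = x if d < 0.5 else y
    # lambdaInterpolate: select on the runtime interpolation mode (0 = linear).
    return linear if i_mode == 0 else nearest

Folding the mode check into the innermost computation keeps the lowering to a single linalg.generic at the cost of evaluating both candidates; since both are a handful of scalar ops feeding an arith.select, that is a cheap trade.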
@@ -2545,6 +2569,7 @@ public:
     if (resultType.isDynamicDim(3))
       resultSize.push_back(rewriter.create<tensor::DimOp>(loc, grid, 2));
     Value alignCorners = adaptor.getAlignCorners();
+    Value interMode = adaptor.getInterpolationMode();
     Value resultFinal =
         rewriter.create<tensor::EmptyOp>(loc, resultType, resultSize);
     auto sGrid = rewriter.create<linalg::GenericOp>(
@@ -2633,10 +2658,12 @@ public:
           Value lw1a = b.create<arith::SIToFPOp>(loc, floatType, lower1);
           Value d1 = b.create<arith::SubFOp>(loc, result0, lw0a);
           Value d0 = b.create<arith::SubFOp>(loc, result1, lw1a);
-          Value resultScaled0 = lambdaInter(b, loc, result00b, result01b, d0);
-          Value resultScaled1 = lambdaInter(b, loc, result10b, result11b, d0);
-          Value resultScaled =
-              lambdaInter(b, loc, resultScaled0, resultScaled1, d1);
+          Value resultScaled0 =
+              lambdaInterpolate(b, loc, interMode, result00b, result01b, d0);
+          Value resultScaled1 =
+              lambdaInterpolate(b, loc, interMode, result10b, result11b, d0);
+          Value resultScaled = lambdaInterpolate(
+              b, loc, interMode, resultScaled0, resultScaled1, d1);
           b.create<linalg::YieldOp>(loc, resultScaled);
         });
     rewriter.replaceOp(op, sGrid.getResults());
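The generic body applies the helper three times to reduce the 2x2 neighborhood: twice along one axis with fractional offset d0, then once across those results with d1. Reusing the interpolate sketch from above:

def sample_2d(i_mode, v00, v01, v10, v11, d0, d1):
    r0 = interpolate(i_mode, v00, v01, d0)  # resultScaled0
    r1 = interpolate(i_mode, v10, v11, d0)  # resultScaled1
    return interpolate(i_mode, r0, r1, d1)  # resultScaled

For i_mode == 1 the cascade of selects simply picks the corner nearest in both axes, so no dedicated nearest-neighbor gather path is needed.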
@@ -115,3 +115,41 @@ def GridSamplerBasic3_basic(module, tu: TestUtils):
         [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
     ).type(torch.FloatTensor)
     module.forward(inp, grd)
+
+
+class GridSamplerBasic4(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [None, ([1, 1, 4, 4], torch.float32, True), ([1, 1, 3, 2], torch.float32, True)]
+    )
+    def forward(self, x, g):
+        interpolation_mode = (1,)
+        padding_mode = (0,)
+        align_corners = (False,)
+        tRes = torch.ops.aten.grid_sampler(
+            x, g, interpolation_mode[0], padding_mode[0], align_corners[0]
+        )
+        return tRes
+
+
+@register_test_case(module_factory=lambda: GridSamplerBasic4())
+def GridSamplerBasic4_basic(module, tu: TestUtils):
+    inp = torch.tensor(
+        [
+            [
+                [
+                    [0.4963, 0.7682, 0.0885, 0.1320],
+                    [0.3074, 0.6341, 0.4901, 0.8964],
+                    [0.4556, 0.6323, 0.3489, 0.4017],
+                    [0.0223, 0.1689, 0.2939, 0.5185],
+                ]
+            ]
+        ]
+    ).type(torch.FloatTensor)
+    grd = torch.tensor(
+        [[[[-0.3498, -0.8196], [-0.2127, 0.2138], [-0.6515, -0.0513]]]]
+    ).type(torch.FloatTensor)
+    module.forward(inp, grd)
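The e2e harness derives the golden output by running the same module in eager PyTorch, so no expected values appear in the test. For local debugging, an equivalent eager check (hypothetical, not part of the suite) would be:

# interpolation_mode=1 (nearest), padding_mode=0 (zeros), align_corners=False
golden = torch.ops.aten.grid_sampler(inp, grd, 1, 0, False)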
@@ -545,7 +545,7 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 // CHECK: %[[B0:.*]] = torch.constant.bool false
 // CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
 func.func @test_grid_sampler01(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
-  %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
+  %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.mode = "linear", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
   return %0 : !torch.vtensor<[?,?,?,?],f32>
 }
@@ -563,6 +563,18 @@ func.func @test_grid_sampler02(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !t

 // -----

+// CHECK-LABEL: @test_grid_sampler03
+// CHECK: %[[INT1:.*]] = torch.constant.int 1
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[B0:.*]] = torch.constant.bool true
+// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT1]], %[[INT0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
+func.func @test_grid_sampler03(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  %0 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.mode = "nearest", torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
+  return %0 : !torch.vtensor<[?,?,?,?],f32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_less_or_equal
 func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
@@ -2,27 +2,32 @@
 // CHECK: #map
 // CHECK-LABEL: func @grid_sampler
-// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32>
-// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32>
-// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
-// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
-// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32>
-// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
-// CHECK-DAG: %[[DIM_4:.*]] = tensor.dim %[[TC0]], %[[C3]] : tensor<4x10x10x4xf32>
-// CHECK-DAG: %[[X2:.*]] = arith.subi %[[DIM:.*]], %[[C1]] : index
-// CHECK-DAG: %[[X3:.*]] = arith.subi %[[DIM_4]], %[[C1:.*]] : index
-// CHECK-DAG: %[[X4:.*]] = arith.index_cast %[[X2]] : index to i64
-// CHECK-DAG: %[[X5:.*]] = arith.index_cast %[[X3]] : index to i64
-// CHECK-DAG: %[[X6:.*]] = arith.sitofp %[[X4]] : i64 to f32
-// CHECK-DAG: %[[X7:.*]] = arith.sitofp %[[X5]] : i64 to f32
-// CHECK-DAG: %[[X8:.*]] = arith.divf %[[X6]], %[[CST2]] : f32
-// CHECK-DAG: %[[X9:.*]] = arith.divf %[[X7]], %[[CST2]] : f32
+// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32>
+// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32>
+// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32
+// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32>
+// CHECK-DAG: %[[X73:.*]] = arith.cmpi eq, %[[X3:.*]], %[[C27:.*]] : i64
+// CHECK-DAG: %[[X74:.*]] = arith.select %[[X73:.*]], %[[X70:.*]], %[[X72:.*]] : f32
+// CHECK-DAG: %[[X75:.*]] = arith.subf %[[Xcst_1:.*]], %[[X57:.*]] : f32
+// CHECK-DAG: %[[X76:.*]] = arith.mulf %[[X66:.*]], %[[X75:.*]] : f32
+// CHECK-DAG: %[[X77:.*]] = arith.mulf %[[X74:.*]], %[[X57:.*]] : f32
+// CHECK-DAG: %[[X78:.*]] = arith.addf %[[X76:.*]], %[[X77:.*]] : f32
+// CHECK-DAG: %[[C28:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[X79:.*]] = arith.cmpf olt, %[[X57:.*]], %[[X28:.*]] : f32
+// CHECK-DAG: %[[X80:.*]] = arith.select %[[X79:.*]], %[[X66:.*]], %[[X74:.*]] : f32
+// CHECK-DAG: %[[C29:.*]] = arith.constant 0 : i64
+// CHECK-DAG: %[[X81:.*]] = arith.cmpi eq, %[[X3:.*]], %[[C29:.*]] : i64
+// CHECK-DAG: %[[X82:.*]] = arith.select %[[X81:.*]], %[[X78:.*]], %[[X80:.*]] : f32
+// CHECK-DAG: linalg.yield %[[X82:.*]] : f32
+// CHECK-DAG: %[[X14:.*]] = torch_c.from_builtin_tensor %[[X13:.*]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>

 func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vtensor<[4,6,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %true = torch.constant.bool 0
   %int0 = torch.constant.int 0
@@ -35,22 +40,25 @@ func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vt

 // CHECK-LABEL: func @grid_sampler2
 // CHECK: #map
-// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32
-// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32
-// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32
-// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32
-// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32
-// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32
-// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32
-// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32
-// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32
-// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32
-// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32
-// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32
-// CHECK-DAG: linalg.yield %[[X50]] : f32
-// CHECK: } -> tensor<?x?x?x?xf32>
-// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
-// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32>
+// CHECK-DAG: %[[X70:.*]] = arith.addf %[[X68:.*]], %[[X69:.*]] : f32
+// CHECK-DAG: %[[X29:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[X71:.*]] = arith.cmpf olt, %[[X58:.*]], %[[X29:.*]] : f32
+// CHECK-DAG: %[[X72:.*]] = arith.select %[[X71:.*]], %[[X52:.*]], %[[X54:.*]] : f32
+// CHECK-DAG: %[[X30:.*]] = arith.constant 0 : i64
+// CHECK-DAG: %[[X73:.*]] = arith.cmpi eq, %[[X3:.*]], %[[X30:.*]] : i64
+// CHECK-DAG: %[[X74:.*]] = arith.select %[[X73:.*]], %[[X70:.*]], %[[X72:.*]] : f32
+// CHECK-DAG: %[[X75:.*]] = arith.subf %[[X1:.*]], %[[X57:.*]] : f32
+// CHECK-DAG: %[[X76:.*]] = arith.mulf %[[X66:.*]], %[[X75:.*]] : f32
+// CHECK-DAG: %[[X77:.*]] = arith.mulf %[[X74:.*]], %[[X57:.*]] : f32
+// CHECK-DAG: %[[X78:.*]] = arith.addf %[[X76:.*]], %[[X77:.*]] : f32
+// CHECK-DAG: %[[X31:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[X79:.*]] = arith.cmpf olt, %[[X57:.*]], %[[X31:.*]] : f32
+// CHECK-DAG: %[[X80:.*]] = arith.select %[[X79:.*]], %[[X66:.*]], %[[X74:.*]] : f32
+// CHECK-DAG: %[[X32:.*]] = arith.constant 0 : i64
+// CHECK-DAG: %[[X81:.*]] = arith.cmpi eq, %[[X3:.*]], %[[X32:.*]] : i64
+// CHECK-DAG: %[[X82:.*]] = arith.select %[[X81:.*]], %[[X78:.*]], %[[X80:.*]] : f32
+// CHECK-DAG: linalg.yield %[[X50:.*]] : f32
+// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32>
 func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %true = torch.constant.bool 0
   %int0 = torch.constant.int 0
@@ -64,21 +72,21 @@ func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte

 // CHECK-LABEL: func @grid_sampler3
 // CHECK: #map
-// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32
-// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32
-// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32
-// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32
-// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32
-// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32
-// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32
-// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32
-// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32
-// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32
-// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32
-// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32
-// CHECK-DAG: linalg.yield %[[X50]] : f32
-// CHECK: } -> tensor<?x?x?x?xf32>
-// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor<?x?x?x?xf32> -> !torch.vtensor<[?,?,?,?],f32>
-// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32>
+// CHECK-DAG: %[[Y60:.*]] = arith.mulf %[[X48:.*]], %[[X59:.*]] : f32
+// CHECK-DAG: %[[Y61:.*]] = arith.mulf %[[X50:.*]], %[[X58:.*]] : f32
+// CHECK-DAG: %[[Y62:.*]] = arith.addf %[[X60:.*]], %[[X61:.*]] : f32
+// CHECK-DAG: %[[Y28:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[Y64:.*]] = arith.select %[[X63:.*]], %[[X48:.*]], %[[X50:.*]] : f32
+// CHECK-DAG: %[[Y29:.*]] = arith.constant 0 : i64
+// CHECK-DAG: %[[Y65:.*]] = arith.cmpi eq, %[[X3:.*]], %[[X28:.*]] : i64
+// CHECK-DAG: %[[Y66:.*]] = arith.select %[[X65:.*]], %[[X62:.*]], %[[X64:.*]] : f32
+// CHECK-DAG: %[[Y67:.*]] = arith.subf %[[X1:.*]], %[[X58:.*]] : f32
+// CHECK-DAG: %[[Y68:.*]] = arith.mulf %[[X52:.*]], %[[X67:.*]] : f32
+// CHECK-DAG: %[[Y69:.*]] = arith.mulf %[[X54:.*]], %[[X58:.*]] : f32
+// CHECK-DAG: %[[Y70:.*]] = arith.addf %[[X68:.*]], %[[X69:.*]] : f32
+// CHECK-DAG: %[[Y30:.*]] = arith.constant 5.000000e-01 : f32
+// CHECK-DAG: %[[Y31:.*]] = arith.constant 0 : i64
+// CHECK: return %[[X12:.*]] : !torch.vtensor<[?,?,?,?],f32>
 func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
   %false = torch.constant.bool 1
   %int0 = torch.constant.int 0
@@ -86,3 +94,14 @@ func.func @grid_sampler3(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vte
   %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
   return %4 : !torch.vtensor<[?,?,?,?],f32>
 }
+
+// -----
+
+// CHECK-LABEL: func @grid_sampler4
+func.func @grid_sampler4(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> {
+  %false = torch.constant.bool 1
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
+  return %4 : !torch.vtensor<[?,?,?,?],f32>
+}