mirror of https://github.com/llvm/torch-mlir
[onnx][torch] Lower `onnx.grid_sampler` to the `torch` equivalents (#2952)
This is the lowering of GridSample from ONNX to Torch, using our prior implementation of AtenGridSamplerOp. Several corner-case checks are implemented here; we may later decide to move some of these checks into AtenGridSamplerOp itself instead of the ONNX lowering.
pull/2967/head
parent
e48fe45886
commit
5437f32193
|
@ -92,6 +92,73 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
operand, vApproximate);
|
||||
return success();
|
||||
});
|
||||
patterns.onOp(
    "GridSample", 20,
    [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lower onnx.GridSample to torch.aten.grid_sampler.
      Torch::ValueTensorType resultType;
      Value input;
      Value grid;
      if (binder.tensorOperandAtIndex(input, 0) ||
          binder.tensorOperandAtIndex(grid, 1) ||
          binder.tensorResultType(resultType))
        return rewriter.notifyMatchFailure(
            binder.op, "operand grid_sampler bind failure");

      auto inputTensorType = input.getType().cast<Torch::ValueTensorType>();
      ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
      uint32_t inputRank = inputShape.size();
      auto gridTensorType = grid.getType().cast<Torch::ValueTensorType>();
      ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
      uint32_t gridRank = gridShape.size();

      // Only the 4-D case (input N x C x H x W, grid N x H_out x W_out x 2)
      // is supported for now; AtenGridSamplerOp expects matching batch dims
      // and a trailing coordinate pair in the grid.
      if (inputRank != 4)
        return rewriter.notifyMatchFailure(binder.op,
                                           "only input rank 4 supported");
      if (gridRank != 4)
        return rewriter.notifyMatchFailure(binder.op,
                                           "only grid rank 4 supported");
      if (inputShape[0] != gridShape[0])
        return rewriter.notifyMatchFailure(
            binder.op, "N must be same for input and grid");
      if (gridShape[3] != 2)
        return rewriter.notifyMatchFailure(binder.op,
                                           "gridShape[3] expected to be 2");

      // ONNX opset 20 renamed the 2-D interpolation mode "bilinear"
      // (opset <= 16) to "linear" and made "linear" the default. Accept
      // both spellings, since both map to the same torch interpolation
      // enum value (0) used below.
      std::string mode;
      if (binder.customOpNameStringAttr(mode, "mode", "linear"))
        return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
      if (mode != "linear" && mode != "bilinear")
        return rewriter.notifyMatchFailure(
            binder.op, "currently only mode : linear supported");
      std::string padding;
      if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
        return rewriter.notifyMatchFailure(binder.op,
                                           "padding_mode bind failure");
      if (padding != "zeros")
        return rewriter.notifyMatchFailure(
            binder.op, "currently only padding_mode : zeros supported");
      int64_t align;
      if (binder.s64IntegerAttr(align, "align_corners", 0))
        return rewriter.notifyMatchFailure(binder.op,
                                           "align_corners bind failure");
      if (align != 0)
        return rewriter.notifyMatchFailure(
            binder.op, "currently only align_corners : 0 supported");

      // aten.grid_sampler takes integer enums for the interpolation and
      // padding modes (0 selects bilinear / zeros respectively — see the
      // torch grid_sample documentation) plus a bool for align_corners.
      Value interpolationMode = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
          binder.getLoc(), rewriter.getType<Torch::BoolType>(),
          rewriter.getBoolAttr(false));

      rewriter.replaceOpWithNewOp<Torch::AtenGridSamplerOp>(
          binder.op, resultType, input, grid, interpolationMode, paddingMode,
          alignCorners);
      return success();
    });
|
||||
patterns.onOp("Less", 13,
|
||||
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
|
||||
Torch::ValueTensorType resultType;
|
||||
|
|
|
@ -395,6 +395,18 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
|
|||
|
||||
// -----
|
||||
|
||||
// GridSample with padding_mode = "zeros" and align_corners = 0 (mode left at
// its default) lowers to torch.aten.grid_sampler with interpolation-mode and
// padding-mode integer constants of 0 and align_corners == false.
// CHECK-LABEL: @test_grid_sampler
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[B0:.*]] = torch.constant.bool false
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
  return %4 : !torch.vtensor<[?,?,?,?],f32>
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @test_less_or_equal
|
||||
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>
|
||||
|
|
Loading…
Reference in New Issue