[onnx][torch] Lower `onnx.grid_sampler` to the `torch` equivalents (#2952)

This is the lowering of gridsampler from onnx to torch using our prior
implementation of AtenGridSamplerOp.
Here are several checks for cornercases implemented. We may decide to
have part of these checks in AtenGridSamplerOp instead of the onnx
lowering portion.
pull/2967/head
Andreas Falkenberg 2024-02-28 13:52:15 -08:00 committed by GitHub
parent e48fe45886
commit 5437f32193
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 79 additions and 0 deletions

View File

@ -92,6 +92,73 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
operand, vApproximate);
return success();
});
patterns.onOp(
"GridSample", 20,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
Value grid;
if (binder.tensorOperandAtIndex(input, 0) ||
binder.tensorOperandAtIndex(grid, 1) ||
binder.tensorResultType(resultType))
return rewriter.notifyMatchFailure(
binder.op, "operand grid_sampler bind failure");
auto inputTensorType = input.getType().cast<Torch::ValueTensorType>();
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
uint32_t inputRank = inputShape.size();
auto gridTensorType = grid.getType().cast<Torch::ValueTensorType>();
ArrayRef<int64_t> gridShape = gridTensorType.getSizes();
uint32_t gridRank = gridShape.size();
if (inputRank != 4)
return rewriter.notifyMatchFailure(binder.op,
"only input rank 4 supported");
if (gridRank != 4)
return rewriter.notifyMatchFailure(binder.op,
"only grid rank 4 supported");
if (inputShape[0] != gridShape[0])
return rewriter.notifyMatchFailure(
binder.op, "N must be same for input and grid");
if (gridShape[3] != 2)
return rewriter.notifyMatchFailure(binder.op,
"gridShape[3] expected to be 2");
std::string mode;
if (binder.customOpNameStringAttr(mode, "mode", "bilinear"))
return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
if (mode != "bilinear")
return rewriter.notifyMatchFailure(
binder.op, "currently only mode : bilinear supported");
std::string padding;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op,
"padding_mode bind failure");
if (padding != "zeros")
return rewriter.notifyMatchFailure(
binder.op, "currently only padding_mode : zeros supported");
int64_t align;
if (binder.s64IntegerAttr(align, "align_corners", 0))
return rewriter.notifyMatchFailure(binder.op,
"align_corners bind failure");
if (align != 0)
return rewriter.notifyMatchFailure(
binder.op, "currently only align_corners : 0 supported");
Value interpolationMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
binder.getLoc(), rewriter.getType<Torch::BoolType>(),
rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenGridSamplerOp>(
binder.op, resultType, input, grid, interpolationMode, paddingMode,
alignCorners);
return success();
});
patterns.onOp("Less", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;

View File

@ -395,6 +395,18 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
// -----
// CHECK-LABEL: @test_grid_sampler
// CHECK: %[[INT0:.*]] = torch.constant.int 0
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[B0:.*]] = torch.constant.bool false
// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32>
func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32>
return %4 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @test_less_or_equal
func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32>