From 1e09ab2977788a94a876dd1d942ec8520c2ceecf Mon Sep 17 00:00:00 2001 From: Atri Sarkar Date: Fri, 25 Oct 2024 21:32:01 +0530 Subject: [PATCH] [OnnxToTorch][GridSample] Add support for border padding mode --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1f3ff7ac2..b4324a7f7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } std::string padding; + int64_t paddingModeInt; if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) return rewriter.notifyMatchFailure(binder.op, "padding_mode bind failure"); - if (padding != "zeros") + if (padding == "zeros") { + paddingModeInt = 0; + } else if (padding == "border") { + paddingModeInt = 1; + } else { return rewriter.notifyMatchFailure( - binder.op, "currently only padding_mode : zeros supported"); + binder.op, + "currently only padding_mode : zeros and border supported"); + } int64_t align; if (binder.s64IntegerAttr(align, "align_corners", 0)) return rewriter.notifyMatchFailure(binder.op, @@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value paddingMode = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + paddingModeInt)); bool alignMode = align; Value alignCorners = rewriter.create(