[OnnxToTorch][GridSample] Add support for border padding mode

pull/3819/head
Atri Sarkar 2024-10-25 21:32:01 +05:30
parent d76d2b689c
commit 1e09ab2977
1 changed files with 11 additions and 3 deletions

View File

@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
}
std::string padding;
int64_t paddingModeInt;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op,
"padding_mode bind failure");
if (padding != "zeros")
if (padding == "zeros") {
paddingModeInt = 0;
} else if (padding == "border") {
paddingModeInt = 1;
} else {
return rewriter.notifyMatchFailure(
binder.op, "currently only padding_mode : zeros supported");
binder.op,
"currently only padding_mode : zeros and border supported");
}
int64_t align;
if (binder.s64IntegerAttr(align, "align_corners", 0))
return rewriter.notifyMatchFailure(binder.op,
@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
paddingModeInt));
bool alignMode = align;
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(