mirror of https://github.com/llvm/torch-mlir
[OnnxToTorch][GridSample] Add support for border padding mode
parent
d76d2b689c
commit
1e09ab2977
|
@ -140,12 +140,19 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
}
|
||||
|
||||
std::string padding;
|
||||
int64_t paddingModeInt;
|
||||
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"padding_mode bind failure");
|
||||
if (padding != "zeros")
|
||||
if (padding == "zeros") {
|
||||
paddingModeInt = 0;
|
||||
} else if (padding == "border") {
|
||||
paddingModeInt = 1;
|
||||
} else {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "currently only padding_mode : zeros supported");
|
||||
binder.op,
|
||||
"currently only padding_mode : zeros and border supported");
|
||||
}
|
||||
int64_t align;
|
||||
if (binder.s64IntegerAttr(align, "align_corners", 0))
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
|
@ -157,7 +164,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
|||
|
||||
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
|
||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
||||
paddingModeInt));
|
||||
|
||||
bool alignMode = align;
|
||||
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
|
||||
|
|
Loading…
Reference in New Issue