[onnx] Update DefaultDomainGtoP.cpp gridsampler (#3228)

Gridsampler
In onnx the interpolation mode is called 'linear' whereas in pytorch it
is called 'bilinear'. This led to the problem that everything other than
'bilinear' was rejected. It needed to be changed to linear.
pull/3238/head
Andreas Falkenberg 2024-04-25 18:07:05 -07:00 committed by GitHub
parent ac11ec796d
commit cd33d8b011
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 3 deletions

View File

@ -124,11 +124,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,
"gridShape[3] expected to be 2"); "gridShape[3] expected to be 2");
std::string mode; std::string mode;
if (binder.customOpNameStringAttr(mode, "mode", "bilinear")) if (binder.customOpNameStringAttr(mode, "mode", "linear"))
return rewriter.notifyMatchFailure(binder.op, "mode bind failure"); return rewriter.notifyMatchFailure(binder.op, "mode bind failure");
if (mode != "bilinear") if (mode != "linear" && mode != "bilinear")
return rewriter.notifyMatchFailure( return rewriter.notifyMatchFailure(
binder.op, "currently only mode : bilinear supported"); binder.op, "currently only mode : linear supported");
std::string padding; std::string padding;
if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros"))
return rewriter.notifyMatchFailure(binder.op, return rewriter.notifyMatchFailure(binder.op,