diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index b4324a7f7..55ddc65b3 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -163,9 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); Value paddingMode = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - paddingModeInt)); + binder.getLoc(), paddingModeInt); bool alignMode = align; Value alignCorners = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 5ecf58947..957937180 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2574,22 +2574,23 @@ public: return b.create(loc, xMaxZero, SizeSubOne); }; - auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode, + auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode, Value x, Value SizeSubOne) -> Value { - Value border = lambdaBorder(b, loc, x, SizeSubOne); - Value zeroInt = - b.create(loc, b.getIntegerAttr(int64type, 0)); - Value isZero = b.create(loc, arith::CmpIPredicate::eq, - paddingMode, zeroInt); + // Border + if (paddingMode == 1) { + return lambdaBorder(b, loc, x, SizeSubOne); + } - return b.create(loc, isZero, x, border); + return x; }; auto resultType = cast( getTypeConverter()->convertType(op.getResult().getType())); Value alignCorners = adaptor.getAlignCorners(); Value interMode = adaptor.getInterpolationMode(); - Value paddingMode = adaptor.getPaddingMode(); + + int64_t paddingModeInt; + matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt)); SmallVector dynamicSizes{}; if (resultType.isDynamicDim(0)) @@ -2623,9 +2624,9 @@ public: Value unnorm1 = b.create(loc, gPlusMul1, gr1HalfSelect); Value result0 = - lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d); + lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d); Value result1 = - lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d); + lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d); Value checkLowerBound0 = b.create( loc, arith::CmpFPredicate::OLT, result0, zeroFloat); Value checkLowerBound1 = b.create(