Simplify paddingMode lowering

Evaluate paddingMode at compile time
pull/3819/head
Atri Sarkar 2024-11-17 21:53:29 +05:30
parent 3ae0446aa2
commit 3395df830c
2 changed files with 12 additions and 13 deletions

View File

@ -163,9 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt)); rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));
Value paddingMode = rewriter.create<Torch::ConstantIntOp>( Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), binder.getLoc(), paddingModeInt);
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
paddingModeInt));
bool alignMode = align; bool alignMode = align;
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>( Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(

View File

@ -2574,22 +2574,23 @@ public:
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne); return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
}; };
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode, auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
Value x, Value SizeSubOne) -> Value { Value x, Value SizeSubOne) -> Value {
Value border = lambdaBorder(b, loc, x, SizeSubOne); // Border
Value zeroInt = if (paddingMode == 1) {
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0)); return lambdaBorder(b, loc, x, SizeSubOne);
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, }
paddingMode, zeroInt);
return b.create<arith::SelectOp>(loc, isZero, x, border); return x;
}; };
auto resultType = cast<RankedTensorType>( auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op.getResult().getType())); getTypeConverter()->convertType(op.getResult().getType()));
Value alignCorners = adaptor.getAlignCorners(); Value alignCorners = adaptor.getAlignCorners();
Value interMode = adaptor.getInterpolationMode(); Value interMode = adaptor.getInterpolationMode();
Value paddingMode = adaptor.getPaddingMode();
int64_t paddingModeInt;
matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt));
SmallVector<Value> dynamicSizes{}; SmallVector<Value> dynamicSizes{};
if (resultType.isDynamicDim(0)) if (resultType.isDynamicDim(0))
@ -2623,9 +2624,9 @@ public:
Value unnorm1 = Value unnorm1 =
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect); b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
Value result0 = Value result0 =
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d); lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d);
Value result1 = Value result1 =
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d); lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d);
Value checkLowerBound0 = b.create<arith::CmpFOp>( Value checkLowerBound0 = b.create<arith::CmpFOp>(
loc, arith::CmpFPredicate::OLT, result0, zeroFloat); loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
Value checkLowerBound1 = b.create<arith::CmpFOp>( Value checkLowerBound1 = b.create<arith::CmpFOp>(