mirror of https://github.com/llvm/torch-mlir
parent
3ae0446aa2
commit
3395df830c
|
@ -163,9 +163,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));
|
rewriter.getIntegerAttr(rewriter.getIntegerType(64), iModeInt));
|
||||||
|
|
||||||
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
|
Value paddingMode = rewriter.create<Torch::ConstantIntOp>(
|
||||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
binder.getLoc(), paddingModeInt);
|
||||||
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
|
|
||||||
paddingModeInt));
|
|
||||||
|
|
||||||
bool alignMode = align;
|
bool alignMode = align;
|
||||||
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
|
Value alignCorners = rewriter.create<Torch::ConstantBoolOp>(
|
||||||
|
|
|
@ -2574,22 +2574,23 @@ public:
|
||||||
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
|
return b.create<arith::MinimumFOp>(loc, xMaxZero, SizeSubOne);
|
||||||
};
|
};
|
||||||
|
|
||||||
auto lambdaPadding = [&](OpBuilder &b, Location loc, Value paddingMode,
|
auto lambdaPadding = [&](OpBuilder &b, Location loc, int64_t paddingMode,
|
||||||
Value x, Value SizeSubOne) -> Value {
|
Value x, Value SizeSubOne) -> Value {
|
||||||
Value border = lambdaBorder(b, loc, x, SizeSubOne);
|
// Border
|
||||||
Value zeroInt =
|
if (paddingMode == 1) {
|
||||||
b.create<arith::ConstantOp>(loc, b.getIntegerAttr(int64type, 0));
|
return lambdaBorder(b, loc, x, SizeSubOne);
|
||||||
Value isZero = b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
|
}
|
||||||
paddingMode, zeroInt);
|
|
||||||
|
|
||||||
return b.create<arith::SelectOp>(loc, isZero, x, border);
|
return x;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto resultType = cast<RankedTensorType>(
|
auto resultType = cast<RankedTensorType>(
|
||||||
getTypeConverter()->convertType(op.getResult().getType()));
|
getTypeConverter()->convertType(op.getResult().getType()));
|
||||||
Value alignCorners = adaptor.getAlignCorners();
|
Value alignCorners = adaptor.getAlignCorners();
|
||||||
Value interMode = adaptor.getInterpolationMode();
|
Value interMode = adaptor.getInterpolationMode();
|
||||||
Value paddingMode = adaptor.getPaddingMode();
|
|
||||||
|
int64_t paddingModeInt;
|
||||||
|
matchPattern(op.getPaddingMode(), m_TorchConstantInt(&paddingModeInt));
|
||||||
|
|
||||||
SmallVector<Value> dynamicSizes{};
|
SmallVector<Value> dynamicSizes{};
|
||||||
if (resultType.isDynamicDim(0))
|
if (resultType.isDynamicDim(0))
|
||||||
|
@ -2623,9 +2624,9 @@ public:
|
||||||
Value unnorm1 =
|
Value unnorm1 =
|
||||||
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
|
b.create<arith::AddFOp>(loc, gPlusMul1, gr1HalfSelect);
|
||||||
Value result0 =
|
Value result0 =
|
||||||
lambdaPadding(b, loc, paddingMode, unnorm0, innerDim0d);
|
lambdaPadding(b, loc, paddingModeInt, unnorm0, innerDim0d);
|
||||||
Value result1 =
|
Value result1 =
|
||||||
lambdaPadding(b, loc, paddingMode, unnorm1, innerDim1d);
|
lambdaPadding(b, loc, paddingModeInt, unnorm1, innerDim1d);
|
||||||
Value checkLowerBound0 = b.create<arith::CmpFOp>(
|
Value checkLowerBound0 = b.create<arith::CmpFOp>(
|
||||||
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
|
loc, arith::CmpFPredicate::OLT, result0, zeroFloat);
|
||||||
Value checkLowerBound1 = b.create<arith::CmpFOp>(
|
Value checkLowerBound1 = b.create<arith::CmpFOp>(
|
||||||
|
|
Loading…
Reference in New Issue