mirror of https://github.com/llvm/torch-mlir
OnnxToTorch bicubic interpolation (#3802)
(https://github.com/nod-ai/SHARK-TestSuite/pull/391) Repro (using SHARK TestSuite): 1. `python run.py --torchtolinalg -m cl-onnx-iree -t cubic_test` --------- Co-authored-by: zjgarvey <zjgarvey@gmail.com>pull/3759/head
parent
17c1985c4d
commit
889a836b3d
|
@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
llvm::SmallVector<Value> operands;
|
llvm::SmallVector<Value> operands;
|
||||||
std::string mode, nearest_mode, coordTfMode;
|
std::string mode, nearest_mode, coordTfMode;
|
||||||
int64_t antialias, exclude_outside;
|
int64_t antialias, exclude_outside;
|
||||||
float extrapolation_value;
|
float extrapolation_value, cubic_coeff_a;
|
||||||
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
|
||||||
|
|
||||||
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
|
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
|
||||||
|
@ -2947,7 +2947,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
|
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
|
||||||
0.0) ||
|
0.0) ||
|
||||||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
|
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
|
||||||
"round_prefer_floor"))
|
"round_prefer_floor") ||
|
||||||
|
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
|
||||||
return failure();
|
return failure();
|
||||||
if (antialias != 0) {
|
if (antialias != 0) {
|
||||||
return rewriter.notifyMatchFailure(
|
return rewriter.notifyMatchFailure(
|
||||||
|
@ -2976,6 +2977,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
"except asymmetric and half_pixel");
|
"except asymmetric and half_pixel");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (mode == "cubic" && cubic_coeff_a != -0.75) {
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
binder.op, "unimplemented: cubic coeff must be -0.75");
|
||||||
|
}
|
||||||
|
|
||||||
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
|
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
|
||||||
.getSizes()
|
.getSizes()
|
||||||
.size();
|
.size();
|
||||||
|
@ -2991,8 +2997,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
Value alignCorners =
|
Value alignCorners =
|
||||||
coordTfMode == "align_corners" ? cstTrue : cstFalse;
|
coordTfMode == "align_corners" ? cstTrue : cstFalse;
|
||||||
if (mode == "cubic") {
|
if (mode == "cubic") {
|
||||||
return rewriter.notifyMatchFailure(binder.op,
|
std::string modeStr = "cubic";
|
||||||
"unimplemented: bicubic mode");
|
if (coordTfMode != "half_pixel")
|
||||||
|
modeStr = modeStr + "_" + coordTfMode;
|
||||||
|
modeStrValue =
|
||||||
|
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
|
||||||
}
|
}
|
||||||
// supported modes:
|
// supported modes:
|
||||||
// bilinear (half_pixel), bilinear with align_corners,
|
// bilinear (half_pixel), bilinear with align_corners,
|
||||||
|
|
|
@ -2683,7 +2683,7 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static Value NearestInterpolate(OpBuilder &b, Location loc,
|
static Value nearestInterpolate(OpBuilder &b, Location loc,
|
||||||
SmallVector<Value> outputSizes, Value input,
|
SmallVector<Value> outputSizes, Value input,
|
||||||
SmallVector<Value> inputSizes,
|
SmallVector<Value> inputSizes,
|
||||||
SmallVector<Value> scaleValues,
|
SmallVector<Value> scaleValues,
|
||||||
|
@ -2771,12 +2771,12 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
|
||||||
return retVal;
|
return retVal;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Value BilinearInterpolate(OpBuilder &b,
|
static SmallVector<Value> coordinateTransform(
|
||||||
Aten__InterpolateSizeListScaleListOp op,
|
OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc,
|
||||||
Location loc, SmallVector<Value> outputSizes,
|
SmallVector<Value> outputSizes, Value input, SmallVector<Value> inputSizes,
|
||||||
Value input, SmallVector<Value> inputSizes,
|
SmallVector<Value> scaleValues, std::string coordStr, bool alignCornersBool,
|
||||||
SmallVector<Value> scaleValues,
|
SmallVector<Value> indices, bool clip) {
|
||||||
std::string coordStr) {
|
|
||||||
unsigned dimOffset = 2;
|
unsigned dimOffset = 2;
|
||||||
auto inputType = cast<RankedTensorType>(input.getType());
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
auto inputRank = inputType.getRank();
|
auto inputRank = inputType.getRank();
|
||||||
|
@ -2785,15 +2785,7 @@ static Value BilinearInterpolate(OpBuilder &b,
|
||||||
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
|
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
|
||||||
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
|
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
|
||||||
|
|
||||||
bool alignCornersBool;
|
SmallVector<Value> proj;
|
||||||
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
|
|
||||||
|
|
||||||
SmallVector<Value> indices;
|
|
||||||
for (unsigned i = 0; i < inputRank; i++) {
|
|
||||||
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
|
||||||
}
|
|
||||||
|
|
||||||
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
|
|
||||||
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
|
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
|
||||||
// length_original
|
// length_original
|
||||||
Value inputFP =
|
Value inputFP =
|
||||||
|
@ -2856,6 +2848,7 @@ static Value BilinearInterpolate(OpBuilder &b,
|
||||||
outputSizeFP, cstOneFloat);
|
outputSizeFP, cstOneFloat);
|
||||||
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
|
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
|
||||||
}
|
}
|
||||||
|
if (clip) {
|
||||||
// preClip is the fp position inside the input image to extract from.
|
// preClip is the fp position inside the input image to extract from.
|
||||||
// clip to [0,inf)
|
// clip to [0,inf)
|
||||||
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
|
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
|
||||||
|
@ -2863,6 +2856,42 @@ static Value BilinearInterpolate(OpBuilder &b,
|
||||||
// clip to [0,length_original - 1].
|
// clip to [0,length_original - 1].
|
||||||
// proj is properly within the input image.
|
// proj is properly within the input image.
|
||||||
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
|
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
|
||||||
|
} else {
|
||||||
|
proj.push_back(preClip);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return proj;
|
||||||
|
}
|
||||||
|
|
||||||
|
static Value bilinearInterpolate(OpBuilder &b,
|
||||||
|
Aten__InterpolateSizeListScaleListOp op,
|
||||||
|
Location loc, SmallVector<Value> outputSizes,
|
||||||
|
Value input, SmallVector<Value> inputSizes,
|
||||||
|
SmallVector<Value> scaleValues,
|
||||||
|
std::string coordStr) {
|
||||||
|
unsigned dimOffset = 2;
|
||||||
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
|
auto inputRank = inputType.getRank();
|
||||||
|
|
||||||
|
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
|
||||||
|
|
||||||
|
bool alignCornersBool;
|
||||||
|
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
|
||||||
|
|
||||||
|
SmallVector<Value> indices;
|
||||||
|
for (unsigned i = 0; i < inputRank; i++) {
|
||||||
|
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> proj, high, low, highFP, lowFP;
|
||||||
|
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
|
||||||
|
scaleValues, coordStr, alignCornersBool, indices,
|
||||||
|
true);
|
||||||
|
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
|
||||||
|
// length_original
|
||||||
|
Value inputFP =
|
||||||
|
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
|
||||||
|
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
|
||||||
|
|
||||||
// for bilinear interpolation, we look for the nearest indices below and
|
// for bilinear interpolation, we look for the nearest indices below and
|
||||||
// above proj
|
// above proj
|
||||||
|
@ -2926,6 +2955,176 @@ static Value BilinearInterpolate(OpBuilder &b,
|
||||||
return b.create<arith::AddFOp>(loc, left, right);
|
return b.create<arith::AddFOp>(loc, left, right);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Value bicubicInterpolate(OpBuilder &b,
|
||||||
|
Aten__InterpolateSizeListScaleListOp op,
|
||||||
|
Location loc, SmallVector<Value> outputSizes,
|
||||||
|
Value input, SmallVector<Value> inputSizes,
|
||||||
|
SmallVector<Value> scaleValues,
|
||||||
|
std::string coordStr) {
|
||||||
|
unsigned dimOffset = 2;
|
||||||
|
auto inputType = cast<RankedTensorType>(input.getType());
|
||||||
|
auto inputRank = inputType.getRank();
|
||||||
|
|
||||||
|
Value inputFPH =
|
||||||
|
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
|
||||||
|
Value inputFPW =
|
||||||
|
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);
|
||||||
|
|
||||||
|
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
|
||||||
|
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
|
||||||
|
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
|
||||||
|
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
|
||||||
|
Value cstThreeFloat =
|
||||||
|
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
|
||||||
|
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
|
||||||
|
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
|
||||||
|
Value cstEightFloat =
|
||||||
|
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));
|
||||||
|
|
||||||
|
// (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1)
|
||||||
|
auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
|
||||||
|
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
|
||||||
|
Value xDistanceCubed =
|
||||||
|
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
|
||||||
|
|
||||||
|
Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
|
||||||
|
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
|
||||||
|
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
|
||||||
|
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
|
||||||
|
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
|
||||||
|
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
|
||||||
|
|
||||||
|
return lessEqualOne;
|
||||||
|
};
|
||||||
|
|
||||||
|
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2)
|
||||||
|
auto WeightLessThanTwo = [&](Value xDistance) -> Value {
|
||||||
|
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
|
||||||
|
Value xDistanceCubed =
|
||||||
|
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
|
||||||
|
// a|x|^3
|
||||||
|
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
|
||||||
|
|
||||||
|
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
|
||||||
|
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
|
||||||
|
// a|x|^3 - 5a|x|^2
|
||||||
|
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);
|
||||||
|
|
||||||
|
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
|
||||||
|
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
|
||||||
|
// a|x|^3 - 5a|x|^2 + 8a|x|
|
||||||
|
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
|
||||||
|
|
||||||
|
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
|
||||||
|
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a
|
||||||
|
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
|
||||||
|
return lessThanTwo;
|
||||||
|
};
|
||||||
|
|
||||||
|
bool alignCornersBool;
|
||||||
|
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
|
||||||
|
|
||||||
|
SmallVector<Value> indices;
|
||||||
|
for (unsigned i = 0; i < inputRank; i++) {
|
||||||
|
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value> proj;
|
||||||
|
|
||||||
|
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
|
||||||
|
scaleValues, coordStr, alignCornersBool, indices,
|
||||||
|
false);
|
||||||
|
|
||||||
|
// get the nearest neighbors of proj
|
||||||
|
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
|
||||||
|
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
|
||||||
|
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
|
||||||
|
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);
|
||||||
|
|
||||||
|
Value y1 = b.create<math::CeilOp>(loc, proj[0]);
|
||||||
|
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
|
||||||
|
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
|
||||||
|
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);
|
||||||
|
|
||||||
|
// calculate the distance of nearest neighbors x and y to proj
|
||||||
|
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
|
||||||
|
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
|
||||||
|
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
|
||||||
|
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
|
||||||
|
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
|
||||||
|
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
|
||||||
|
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
|
||||||
|
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
|
||||||
|
|
||||||
|
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
|
||||||
|
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
|
||||||
|
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
|
||||||
|
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
|
||||||
|
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
|
||||||
|
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
|
||||||
|
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
|
||||||
|
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
|
||||||
|
|
||||||
|
SmallVector<Value> y{y_2, y_1, y1, y2};
|
||||||
|
SmallVector<Value> x{x_2, x_1, x1, x2};
|
||||||
|
|
||||||
|
SmallVector<Value> wys{
|
||||||
|
WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance),
|
||||||
|
WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)};
|
||||||
|
SmallVector<Value> wxs{
|
||||||
|
WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance),
|
||||||
|
WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)};
|
||||||
|
|
||||||
|
// clip the nearest neighbors points to inside the original image
|
||||||
|
for (int k = 0; k < 4; k++) {
|
||||||
|
Value yClipped = b.create<arith::MaximumFOp>(loc, y[k], zero);
|
||||||
|
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
|
||||||
|
yClipped = b.create<arith::MinimumFOp>(loc, yClipped, inputHSubOne);
|
||||||
|
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yClipped);
|
||||||
|
y[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
|
||||||
|
|
||||||
|
Value xClipped = b.create<arith::MaximumFOp>(loc, x[k], zero);
|
||||||
|
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
|
||||||
|
xClipped = b.create<arith::MinimumFOp>(loc, xClipped, inputWSubOne);
|
||||||
|
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xClipped);
|
||||||
|
x[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
|
||||||
|
}
|
||||||
|
// 1. Compute x_original and y_original (proj)
|
||||||
|
// 2. Compute nearest x and y neighbors
|
||||||
|
// 3. Compute Wx Wy
|
||||||
|
// 4. Extract inputs at nearest neighbors (inputExtracts)
|
||||||
|
// 5. Compute weighted sum (yield this)
|
||||||
|
|
||||||
|
// 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original
|
||||||
|
// 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original
|
||||||
|
// Sum_x is over 4 nearest x neighbors (similar for Sum_y)
|
||||||
|
// f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y]
|
||||||
|
// * W(y_original - y)
|
||||||
|
Value fxy = zero;
|
||||||
|
|
||||||
|
for (int j = 0; j < 4; j++) {
|
||||||
|
Value wy = wys[j];
|
||||||
|
Value xInterpy = zero;
|
||||||
|
|
||||||
|
indices[dimOffset] = y[j];
|
||||||
|
|
||||||
|
for (int i = 0; i < 4; i++) {
|
||||||
|
Value wx = wxs[i];
|
||||||
|
|
||||||
|
indices[dimOffset + 1] = x[i];
|
||||||
|
|
||||||
|
Value p = b.create<tensor::ExtractOp>(loc, input, indices);
|
||||||
|
|
||||||
|
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
|
||||||
|
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
|
||||||
|
}
|
||||||
|
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
|
||||||
|
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
|
||||||
|
}
|
||||||
|
|
||||||
|
return fxy;
|
||||||
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class ConvertInterpolateOp
|
class ConvertInterpolateOp
|
||||||
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
|
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
|
||||||
|
@ -2941,7 +3140,8 @@ public:
|
||||||
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
|
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
|
||||||
// op with the non-standard mode="bilinear_asymmetric".
|
// op with the non-standard mode="bilinear_asymmetric".
|
||||||
matchPattern(op.getMode(), m_TorchConstantStr(mode));
|
matchPattern(op.getMode(), m_TorchConstantStr(mode));
|
||||||
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
|
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
|
||||||
|
mode.substr(0, 5) != "cubic") {
|
||||||
return failure();
|
return failure();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -3023,13 +3223,18 @@ public:
|
||||||
(mode.find(",") == std::string::npos)
|
(mode.find(",") == std::string::npos)
|
||||||
? ""
|
? ""
|
||||||
: mode.substr(mode.find(",") + 1);
|
: mode.substr(mode.find(",") + 1);
|
||||||
retVal = NearestInterpolate(
|
retVal = nearestInterpolate(
|
||||||
b, loc, outputSizeIntValues, input, inputSizes,
|
b, loc, outputSizeIntValues, input, inputSizes,
|
||||||
ScaleFactorFloatValues, coordTfMode, nearestMode);
|
ScaleFactorFloatValues, coordTfMode, nearestMode);
|
||||||
} else if (mode.substr(0, 8) == "bilinear") {
|
} else if (mode.substr(0, 8) == "bilinear") {
|
||||||
retVal = BilinearInterpolate(
|
retVal = bilinearInterpolate(
|
||||||
b, op, loc, outputSizeIntValues, input, inputSizes,
|
b, op, loc, outputSizeIntValues, input, inputSizes,
|
||||||
ScaleFactorFloatValues, mode.substr(8));
|
ScaleFactorFloatValues, mode.substr(8));
|
||||||
|
} else if (mode.substr(0, 5) == "cubic") {
|
||||||
|
|
||||||
|
retVal = bicubicInterpolate(
|
||||||
|
b, op, loc, outputSizeIntValues, input, inputSizes,
|
||||||
|
ScaleFactorFloatValues, mode.substr(5));
|
||||||
}
|
}
|
||||||
b.create<linalg::YieldOp>(loc, retVal);
|
b.create<linalg::YieldOp>(loc, retVal);
|
||||||
})
|
})
|
||||||
|
|
|
@ -21,14 +21,14 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
|
||||||
// CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
|
// CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
|
||||||
// CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32
|
// CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32
|
||||||
// CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32
|
// CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32
|
||||||
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32
|
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32
|
||||||
// CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32
|
// CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32
|
||||||
// CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32
|
// CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32
|
||||||
// CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32
|
// CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32
|
||||||
// CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32
|
// CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32
|
||||||
// CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
|
// CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
|
||||||
// CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
|
// CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
|
||||||
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32
|
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32
|
||||||
// CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64
|
// CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64
|
||||||
// CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index
|
// CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index
|
||||||
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32>
|
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32>
|
||||||
|
@ -304,4 +304,51 @@ func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtens
|
||||||
return %5 : !torch.vtensor<[?,?,?],f32>
|
return %5 : !torch.vtensor<[?,?,?],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func.func @test_resize_sizes_cubic
|
||||||
|
func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
|
||||||
|
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19
|
||||||
|
: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||||
|
// CHECK-DAG: %[[x1:.*]] = math.ceil %36 : f32
|
||||||
|
// CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst_5 : f32
|
||||||
|
// CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst_5 : f32
|
||||||
|
// CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst_5 : f32
|
||||||
|
// CHECK-DAG: %[[y1:.*]] = math.ceil %28 : f32
|
||||||
|
// CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst_5 : f32
|
||||||
|
// CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst_5 : f32
|
||||||
|
// CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst_5 : f32
|
||||||
|
// CHECK-DAG: %[[y2D:.*]] = arith.subf %28, %[[y2]] : f32
|
||||||
|
// CHECK-DAG: %[[y2Dist:.*]] = math.absf %[[y2D]] : f32
|
||||||
|
// CHECK-DAG: %[[y1D:.*]] = arith.subf %28, %[[y1]] : f32
|
||||||
|
// CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32
|
||||||
|
// CHECK-DAG: %[[y_1D:.*]] = arith.subf %28, %[[y_1]] : f32
|
||||||
|
// CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32
|
||||||
|
// CHECK-DAG: %[[y_2D:.*]] = arith.subf %28, %[[y_2]] : f32
|
||||||
|
// CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32
|
||||||
|
// CHECK-DAG: %[[x2D:.*]] = arith.subf %36, %[[x2]] : f32
|
||||||
|
// CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32
|
||||||
|
// CHECK-DAG: %[[x1D:.*]] = arith.subf %36, %[[x1]] : f32
|
||||||
|
// CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32
|
||||||
|
// CHECK-DAG: %[[x_1D:.*]] = arith.subf %36, %[[x_1]] : f32
|
||||||
|
// CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32
|
||||||
|
// CHECK-DAG: %[[x_2D:.*]] = arith.subf %36, %[[x_2]] : f32
|
||||||
|
// CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32
|
||||||
|
// CHECK-DAG: %[[distSQ:.*]] = arith.mulf %52, %52 : f32
|
||||||
|
// CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %52 : f32
|
||||||
|
%none = torch.constant.none
|
||||||
|
%none_0 = torch.constant.none
|
||||||
|
%int0 = torch.constant.int 0
|
||||||
|
%false = torch.constant.bool false
|
||||||
|
%true = torch.constant.bool true
|
||||||
|
%str = torch.constant.str "cubic"
|
||||||
|
%int2 = torch.constant.int 2
|
||||||
|
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%int3 = torch.constant.int 3
|
||||||
|
%2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
|
||||||
|
%3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
|
||||||
|
%4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
return %5 : !torch.vtensor<[?,?,?,?],f32>
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
Loading…
Reference in New Issue