OnnxToTorch bicubic interpolation (#3802)

(https://github.com/nod-ai/SHARK-TestSuite/pull/391)
Repro (using SHARK TestSuite):
1. `python run.py --torchtolinalg -m cl-onnx-iree -t cubic_test`

---------

Co-authored-by: zjgarvey <zjgarvey@gmail.com>
pull/3759/head
aldesilv 2024-11-12 10:54:29 -08:00 committed by GitHub
parent 17c1985c4d
commit 889a836b3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 292 additions and 31 deletions

View File

@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
int64_t antialias, exclude_outside;
float extrapolation_value;
float extrapolation_value, cubic_coeff_a;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
@ -2947,7 +2947,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
binder.f32FloatAttr(extrapolation_value, "extrapolation_value",
0.0) ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode",
"round_prefer_floor"))
"round_prefer_floor") ||
binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75))
return failure();
if (antialias != 0) {
return rewriter.notifyMatchFailure(
@ -2976,6 +2977,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
"except asymmetric and half_pixel");
}
if (mode == "cubic" && cubic_coeff_a != -0.75) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: cubic coeff must be -0.75");
}
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
.getSizes()
.size();
@ -2991,8 +2997,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
std::string modeStr = "cubic";
if (coordTfMode != "half_pixel")
modeStr = modeStr + "_" + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
// supported modes:
// bilinear (half_pixel), bilinear with align_corners,

View File

@ -2683,7 +2683,7 @@ public:
};
} // namespace
static Value NearestInterpolate(OpBuilder &b, Location loc,
static Value nearestInterpolate(OpBuilder &b, Location loc,
SmallVector<Value> outputSizes, Value input,
SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
@ -2771,12 +2771,12 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
return retVal;
}
static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
static SmallVector<Value> coordinateTransform(
OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc,
SmallVector<Value> outputSizes, Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues, std::string coordStr, bool alignCornersBool,
SmallVector<Value> indices, bool clip) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
@ -2785,15 +2785,7 @@ static Value BilinearInterpolate(OpBuilder &b,
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
SmallVector<Value> proj;
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
@ -2856,6 +2848,7 @@ static Value BilinearInterpolate(OpBuilder &b,
outputSizeFP, cstOneFloat);
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
}
if (clip) {
// preClip is the fp position inside the input image to extract from.
// clip to [0,inf)
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
@ -2863,6 +2856,42 @@ static Value BilinearInterpolate(OpBuilder &b,
// clip to [0,length_original - 1].
// proj is properly within the input image.
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
} else {
proj.push_back(preClip);
}
}
return proj;
}
static Value bilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
SmallVector<Value> proj, high, low, highFP, lowFP;
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
scaleValues, coordStr, alignCornersBool, indices,
true);
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
// for bilinear interpolation, we look for the nearest indices below and
// above proj
@ -2926,6 +2955,176 @@ static Value BilinearInterpolate(OpBuilder &b,
return b.create<arith::AddFOp>(loc, left, right);
}
static Value bicubicInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = cast<RankedTensorType>(input.getType());
auto inputRank = inputType.getRank();
Value inputFPH =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[0]);
Value inputFPW =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[1]);
Value a = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(-0.75));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstTwoFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(2.0));
Value cstThreeFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(3.0));
Value cstFourFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(4.0));
Value cstFiveFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(5.0));
Value cstEightFloat =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(8.0));
// (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1)
auto WeightLessThanEqualOne = [&](Value xDistance) -> Value {
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
Value xDistanceCubed =
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
Value lessEqualOne = b.create<arith::AddFOp>(loc, a, cstTwoFloat);
lessEqualOne = b.create<arith::MulFOp>(loc, xDistanceCubed, lessEqualOne);
Value aPlusThree = b.create<arith::AddFOp>(loc, a, cstThreeFloat);
aPlusThree = b.create<arith::MulFOp>(loc, xDistanceSquared, aPlusThree);
lessEqualOne = b.create<arith::SubFOp>(loc, lessEqualOne, aPlusThree);
lessEqualOne = b.create<arith::AddFOp>(loc, lessEqualOne, cstOneFloat);
return lessEqualOne;
};
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2)
auto WeightLessThanTwo = [&](Value xDistance) -> Value {
Value xDistanceSquared = b.create<arith::MulFOp>(loc, xDistance, xDistance);
Value xDistanceCubed =
b.create<arith::MulFOp>(loc, xDistanceSquared, xDistance);
// a|x|^3
Value lessThanTwo = b.create<arith::MulFOp>(loc, xDistanceCubed, a);
Value fiveA = b.create<arith::MulFOp>(loc, xDistanceSquared, a);
fiveA = b.create<arith::MulFOp>(loc, fiveA, cstFiveFloat);
// a|x|^3 - 5a|x|^2
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fiveA);
Value eightA = b.create<arith::MulFOp>(loc, a, xDistance);
eightA = b.create<arith::MulFOp>(loc, eightA, cstEightFloat);
// a|x|^3 - 5a|x|^2 + 8a|x|
lessThanTwo = b.create<arith::AddFOp>(loc, eightA, lessThanTwo);
Value fourA = b.create<arith::MulFOp>(loc, a, cstFourFloat);
// a|x|^3 - 5a|x|^2 + 8a|x| - 4a
lessThanTwo = b.create<arith::SubFOp>(loc, lessThanTwo, fourA);
return lessThanTwo;
};
bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
SmallVector<Value> proj;
proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes,
scaleValues, coordStr, alignCornersBool, indices,
false);
// get the nearest neighbors of proj
Value x1 = b.create<math::CeilOp>(loc, proj[1]);
Value x_1 = b.create<arith::SubFOp>(loc, x1, cstOneFloat);
Value x_2 = b.create<arith::SubFOp>(loc, x_1, cstOneFloat);
Value x2 = b.create<arith::AddFOp>(loc, x1, cstOneFloat);
Value y1 = b.create<math::CeilOp>(loc, proj[0]);
Value y_1 = b.create<arith::SubFOp>(loc, y1, cstOneFloat);
Value y_2 = b.create<arith::SubFOp>(loc, y_1, cstOneFloat);
Value y2 = b.create<arith::AddFOp>(loc, y1, cstOneFloat);
// calculate the distance of nearest neighbors x and y to proj
Value y2Distance = b.create<arith::SubFOp>(loc, proj[0], y2);
y2Distance = b.create<math::AbsFOp>(loc, y2Distance);
Value y1Distance = b.create<arith::SubFOp>(loc, proj[0], y1);
y1Distance = b.create<math::AbsFOp>(loc, y1Distance);
Value y_1Distance = b.create<arith::SubFOp>(loc, proj[0], y_1);
y_1Distance = b.create<math::AbsFOp>(loc, y_1Distance);
Value y_2Distance = b.create<arith::SubFOp>(loc, proj[0], y_2);
y_2Distance = b.create<math::AbsFOp>(loc, y_2Distance);
Value x2Distance = b.create<arith::SubFOp>(loc, proj[1], x2);
x2Distance = b.create<math::AbsFOp>(loc, x2Distance);
Value x1Distance = b.create<arith::SubFOp>(loc, proj[1], x1);
x1Distance = b.create<math::AbsFOp>(loc, x1Distance);
Value x_1Distance = b.create<arith::SubFOp>(loc, proj[1], x_1);
x_1Distance = b.create<math::AbsFOp>(loc, x_1Distance);
Value x_2Distance = b.create<arith::SubFOp>(loc, proj[1], x_2);
x_2Distance = b.create<math::AbsFOp>(loc, x_2Distance);
SmallVector<Value> y{y_2, y_1, y1, y2};
SmallVector<Value> x{x_2, x_1, x1, x2};
SmallVector<Value> wys{
WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance),
WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)};
SmallVector<Value> wxs{
WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance),
WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)};
// clip the nearest neighbors points to inside the original image
for (int k = 0; k < 4; k++) {
Value yClipped = b.create<arith::MaximumFOp>(loc, y[k], zero);
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputFPH, cstOneFloat);
yClipped = b.create<arith::MinimumFOp>(loc, yClipped, inputHSubOne);
Value yInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yClipped);
y[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yInt);
Value xClipped = b.create<arith::MaximumFOp>(loc, x[k], zero);
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputFPW, cstOneFloat);
xClipped = b.create<arith::MinimumFOp>(loc, xClipped, inputWSubOne);
Value xInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xClipped);
x[k] = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xInt);
}
// 1. Compute x_original and y_original (proj)
// 2. Compute nearest x and y neighbors
// 3. Compute Wx Wy
// 4. Extract inputs at nearest neighbors (inputExtracts)
// 5. Compute weighted sum (yield this)
// 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original
// 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original
// Sum_x is over 4 nearest x neighbors (similar for Sum_y)
// f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y]
// * W(y_original - y)
Value fxy = zero;
for (int j = 0; j < 4; j++) {
Value wy = wys[j];
Value xInterpy = zero;
indices[dimOffset] = y[j];
for (int i = 0; i < 4; i++) {
Value wx = wxs[i];
indices[dimOffset + 1] = x[i];
Value p = b.create<tensor::ExtractOp>(loc, input, indices);
Value wxp = b.create<arith::MulFOp>(loc, wx, p);
xInterpy = b.create<arith::AddFOp>(loc, xInterpy, wxp);
}
Value wyXInterpy = b.create<arith::MulFOp>(loc, wy, xInterpy);
fxy = b.create<arith::AddFOp>(loc, fxy, wyXInterpy);
}
return fxy;
}
namespace {
class ConvertInterpolateOp
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
@ -2941,7 +3140,8 @@ public:
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
// op with the non-standard mode="bilinear_asymmetric".
matchPattern(op.getMode(), m_TorchConstantStr(mode));
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" &&
mode.substr(0, 5) != "cubic") {
return failure();
}
@ -3023,13 +3223,18 @@ public:
(mode.find(",") == std::string::npos)
? ""
: mode.substr(mode.find(",") + 1);
retVal = NearestInterpolate(
retVal = nearestInterpolate(
b, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, coordTfMode, nearestMode);
} else if (mode.substr(0, 8) == "bilinear") {
retVal = BilinearInterpolate(
retVal = bilinearInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(8));
} else if (mode.substr(0, 5) == "cubic") {
retVal = bicubicInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(5));
}
b.create<linalg::YieldOp>(loc, retVal);
})

View File

@ -21,14 +21,14 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1:
// CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
// CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32
// CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32
// CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32
// CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32
// CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32
// CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32
// CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32
// CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
// CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32
// CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32
// CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64
// CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32>
@ -304,4 +304,51 @@ func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtens
return %5 : !torch.vtensor<[?,?,?],f32>
}
// CHECK-LABEL: func.func @test_resize_sizes_cubic
func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19
: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK-DAG: %[[x1:.*]] = math.ceil %36 : f32
// CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst_5 : f32
// CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst_5 : f32
// CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst_5 : f32
// CHECK-DAG: %[[y1:.*]] = math.ceil %28 : f32
// CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst_5 : f32
// CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst_5 : f32
// CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst_5 : f32
// CHECK-DAG: %[[y2D:.*]] = arith.subf %28, %[[y2]] : f32
// CHECK-DAG: %[[y2Dist:.*]] = math.absf %[[y2D]] : f32
// CHECK-DAG: %[[y1D:.*]] = arith.subf %28, %[[y1]] : f32
// CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32
// CHECK-DAG: %[[y_1D:.*]] = arith.subf %28, %[[y_1]] : f32
// CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32
// CHECK-DAG: %[[y_2D:.*]] = arith.subf %28, %[[y_2]] : f32
// CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32
// CHECK-DAG: %[[x2D:.*]] = arith.subf %36, %[[x2]] : f32
// CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32
// CHECK-DAG: %[[x1D:.*]] = arith.subf %36, %[[x1]] : f32
// CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32
// CHECK-DAG: %[[x_1D:.*]] = arith.subf %36, %[[x_1]] : f32
// CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32
// CHECK-DAG: %[[x_2D:.*]] = arith.subf %36, %[[x_2]] : f32
// CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32
// CHECK-DAG: %[[distSQ:.*]] = arith.mulf %52, %52 : f32
// CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %52 : f32
%none = torch.constant.none
%none_0 = torch.constant.none
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%true = torch.constant.bool true
%str = torch.constant.str "cubic"
%int2 = torch.constant.int 2
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
%int3 = torch.constant.int 3
%2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
%4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %5 : !torch.vtensor<[?,?,?,?],f32>
}
// -----