From 889a836b3dc58c68f2a0a5fc08512e2a4b56246a Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Tue, 12 Nov 2024 10:54:29 -0800 Subject: [PATCH] OnnxToTorch bicubic interpolation (#3802) (https://github.com/nod-ai/SHARK-TestSuite/pull/391) Repro (using SHARK TestSuite): 1. `python run.py --torchtolinalg -m cl-onnx-iree -t cubic_test` --------- Co-authored-by: zjgarvey --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 17 +- .../TorchToLinalg/Uncategorized.cpp | 255 ++++++++++++++++-- test/Conversion/TorchToLinalg/resize.mlir | 51 +++- 3 files changed, 292 insertions(+), 31 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ea2e0452e..1793af959 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2922,7 +2922,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector operands; std::string mode, nearest_mode, coordTfMode; int64_t antialias, exclude_outside; - float extrapolation_value; + float extrapolation_value, cubic_coeff_a; Value noneVal = rewriter.create(binder.getLoc()); if (auto attr = binder.op->getAttr("torch.onnx.axes")) { @@ -2947,7 +2947,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.f32FloatAttr(extrapolation_value, "extrapolation_value", 0.0) || binder.customOpNameStringAttr(nearest_mode, "nearest_mode", - "round_prefer_floor")) + "round_prefer_floor") || + binder.f32FloatAttr(cubic_coeff_a, "cubic_coeff_a", -0.75)) return failure(); if (antialias != 0) { return rewriter.notifyMatchFailure( @@ -2976,6 +2977,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( "except asymmetric and half_pixel"); } + if (mode == "cubic" && cubic_coeff_a != -0.75) { + return rewriter.notifyMatchFailure( + binder.op, "unimplemented: cubic coeff must be -0.75"); + } + unsigned rank = dyn_cast(operands[0].getType()) .getSizes() .size(); @@ -2991,8 +2997,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value alignCorners = coordTfMode == "align_corners" ? cstTrue : cstFalse; if (mode == "cubic") { - return rewriter.notifyMatchFailure(binder.op, - "unimplemented: bicubic mode"); + std::string modeStr = "cubic"; + if (coordTfMode != "half_pixel") + modeStr = modeStr + "_" + coordTfMode; + modeStrValue = + rewriter.create(binder.getLoc(), modeStr); } // supported modes: // bilinear (half_pixel), bilinear with align_corners, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index c129c9614..35e4144f3 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2683,7 +2683,7 @@ public: }; } // namespace -static Value NearestInterpolate(OpBuilder &b, Location loc, +static Value nearestInterpolate(OpBuilder &b, Location loc, SmallVector outputSizes, Value input, SmallVector inputSizes, SmallVector scaleValues, @@ -2771,12 +2771,12 @@ static Value NearestInterpolate(OpBuilder &b, Location loc, return retVal; } -static Value BilinearInterpolate(OpBuilder &b, - Aten__InterpolateSizeListScaleListOp op, - Location loc, SmallVector outputSizes, - Value input, SmallVector inputSizes, - SmallVector scaleValues, - std::string coordStr) { +static SmallVector coordinateTransform( + OpBuilder &b, Aten__InterpolateSizeListScaleListOp op, Location loc, + SmallVector outputSizes, Value input, SmallVector inputSizes, + SmallVector scaleValues, std::string coordStr, bool alignCornersBool, + SmallVector indices, bool clip) { + unsigned dimOffset = 2; auto inputType = cast(input.getType()); auto inputRank = inputType.getRank(); @@ -2785,15 +2785,7 @@ static Value BilinearInterpolate(OpBuilder &b, Value cstHalf = b.create(loc, b.getF32FloatAttr(0.5)); Value zero = b.create(loc, b.getF32FloatAttr(0.0)); - bool alignCornersBool; - matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); - - SmallVector indices; - for (unsigned i = 0; i < inputRank; i++) { - indices.push_back(b.create(loc, i)); - } - - SmallVector proj, projEps, high, low, highFP, lowFP; + SmallVector proj; for (unsigned i = 0; i < inputRank - dimOffset; i++) { // length_original Value inputFP = @@ -2856,13 +2848,50 @@ static Value BilinearInterpolate(OpBuilder &b, outputSizeFP, cstOneFloat); preClip = b.create(loc, cmp, zero, preClip); } - // preClip is the fp position inside the input image to extract from. - // clip to [0,inf) - Value max = b.create(loc, preClip, zero); + if (clip) { + // preClip is the fp position inside the input image to extract from. + // clip to [0,inf) + Value max = b.create(loc, preClip, zero); + Value inputSubOne = b.create(loc, inputFP, cstOneFloat); + // clip to [0,length_original - 1]. + // proj is properly within the input image. + proj.push_back(b.create(loc, max, inputSubOne)); + } else { + proj.push_back(preClip); + } + } + return proj; +} + +static Value bilinearInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj, high, low, highFP, lowFP; + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices, + true); + for (unsigned i = 0; i < inputRank - dimOffset; i++) { + // length_original + Value inputFP = + b.create(loc, b.getF32Type(), inputSizes[i]); Value inputSubOne = b.create(loc, inputFP, cstOneFloat); - // clip to [0,length_original - 1]. - // proj is properly within the input image. - proj.push_back(b.create(loc, max, inputSubOne)); // for bilinear interpolation, we look for the nearest indices below and // above proj @@ -2926,6 +2955,176 @@ static Value BilinearInterpolate(OpBuilder &b, return b.create(loc, left, right); } +static Value bicubicInterpolate(OpBuilder &b, + Aten__InterpolateSizeListScaleListOp op, + Location loc, SmallVector outputSizes, + Value input, SmallVector inputSizes, + SmallVector scaleValues, + std::string coordStr) { + unsigned dimOffset = 2; + auto inputType = cast(input.getType()); + auto inputRank = inputType.getRank(); + + Value inputFPH = + b.create(loc, b.getF32Type(), inputSizes[0]); + Value inputFPW = + b.create(loc, b.getF32Type(), inputSizes[1]); + + Value a = b.create(loc, b.getF32FloatAttr(-0.75)); + Value zero = b.create(loc, b.getF32FloatAttr(0.0)); + Value cstOneFloat = b.create(loc, b.getF32FloatAttr(1.0)); + Value cstTwoFloat = b.create(loc, b.getF32FloatAttr(2.0)); + Value cstThreeFloat = + b.create(loc, b.getF32FloatAttr(3.0)); + Value cstFourFloat = b.create(loc, b.getF32FloatAttr(4.0)); + Value cstFiveFloat = b.create(loc, b.getF32FloatAttr(5.0)); + Value cstEightFloat = + b.create(loc, b.getF32FloatAttr(8.0)); + + // (a+2)|x|^3 - (a+3)|x|^2 + 1 for xDistance (|x| <= 1) + auto WeightLessThanEqualOne = [&](Value xDistance) -> Value { + Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + + Value lessEqualOne = b.create(loc, a, cstTwoFloat); + lessEqualOne = b.create(loc, xDistanceCubed, lessEqualOne); + Value aPlusThree = b.create(loc, a, cstThreeFloat); + aPlusThree = b.create(loc, xDistanceSquared, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, aPlusThree); + lessEqualOne = b.create(loc, lessEqualOne, cstOneFloat); + + return lessEqualOne; + }; + + // a|x|^3 - 5a|x|^2 + 8a|x| - 4a for xDistance (1 < |x| < 2) + auto WeightLessThanTwo = [&](Value xDistance) -> Value { + Value xDistanceSquared = b.create(loc, xDistance, xDistance); + Value xDistanceCubed = + b.create(loc, xDistanceSquared, xDistance); + // a|x|^3 + Value lessThanTwo = b.create(loc, xDistanceCubed, a); + + Value fiveA = b.create(loc, xDistanceSquared, a); + fiveA = b.create(loc, fiveA, cstFiveFloat); + // a|x|^3 - 5a|x|^2 + lessThanTwo = b.create(loc, lessThanTwo, fiveA); + + Value eightA = b.create(loc, a, xDistance); + eightA = b.create(loc, eightA, cstEightFloat); + // a|x|^3 - 5a|x|^2 + 8a|x| + lessThanTwo = b.create(loc, eightA, lessThanTwo); + + Value fourA = b.create(loc, a, cstFourFloat); + // a|x|^3 - 5a|x|^2 + 8a|x| - 4a + lessThanTwo = b.create(loc, lessThanTwo, fourA); + return lessThanTwo; + }; + + bool alignCornersBool; + matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool)); + + SmallVector indices; + for (unsigned i = 0; i < inputRank; i++) { + indices.push_back(b.create(loc, i)); + } + + SmallVector proj; + + proj = coordinateTransform(b, op, loc, outputSizes, input, inputSizes, + scaleValues, coordStr, alignCornersBool, indices, + false); + + // get the nearest neighbors of proj + Value x1 = b.create(loc, proj[1]); + Value x_1 = b.create(loc, x1, cstOneFloat); + Value x_2 = b.create(loc, x_1, cstOneFloat); + Value x2 = b.create(loc, x1, cstOneFloat); + + Value y1 = b.create(loc, proj[0]); + Value y_1 = b.create(loc, y1, cstOneFloat); + Value y_2 = b.create(loc, y_1, cstOneFloat); + Value y2 = b.create(loc, y1, cstOneFloat); + + // calculate the distance of nearest neighbors x and y to proj + Value y2Distance = b.create(loc, proj[0], y2); + y2Distance = b.create(loc, y2Distance); + Value y1Distance = b.create(loc, proj[0], y1); + y1Distance = b.create(loc, y1Distance); + Value y_1Distance = b.create(loc, proj[0], y_1); + y_1Distance = b.create(loc, y_1Distance); + Value y_2Distance = b.create(loc, proj[0], y_2); + y_2Distance = b.create(loc, y_2Distance); + + Value x2Distance = b.create(loc, proj[1], x2); + x2Distance = b.create(loc, x2Distance); + Value x1Distance = b.create(loc, proj[1], x1); + x1Distance = b.create(loc, x1Distance); + Value x_1Distance = b.create(loc, proj[1], x_1); + x_1Distance = b.create(loc, x_1Distance); + Value x_2Distance = b.create(loc, proj[1], x_2); + x_2Distance = b.create(loc, x_2Distance); + + SmallVector y{y_2, y_1, y1, y2}; + SmallVector x{x_2, x_1, x1, x2}; + + SmallVector wys{ + WeightLessThanTwo(y_2Distance), WeightLessThanEqualOne(y_1Distance), + WeightLessThanEqualOne(y1Distance), WeightLessThanTwo(y2Distance)}; + SmallVector wxs{ + WeightLessThanTwo(x_2Distance), WeightLessThanEqualOne(x_1Distance), + WeightLessThanEqualOne(x1Distance), WeightLessThanTwo(x2Distance)}; + + // clip the nearest neighbors points to inside the original image + for (int k = 0; k < 4; k++) { + Value yClipped = b.create(loc, y[k], zero); + Value inputHSubOne = b.create(loc, inputFPH, cstOneFloat); + yClipped = b.create(loc, yClipped, inputHSubOne); + Value yInt = b.create(loc, b.getI64Type(), yClipped); + y[k] = b.create(loc, b.getIndexType(), yInt); + + Value xClipped = b.create(loc, x[k], zero); + Value inputWSubOne = b.create(loc, inputFPW, cstOneFloat); + xClipped = b.create(loc, xClipped, inputWSubOne); + Value xInt = b.create(loc, b.getI64Type(), xClipped); + x[k] = b.create(loc, b.getIndexType(), xInt); + } + // 1. Compute x_original and y_original (proj) + // 2. Compute nearest x and y neighbors + // 3. Compute Wx Wy + // 4. Extract inputs at nearest neighbors (inputExtracts) + // 5. Compute weighted sum (yield this) + + // 4 nearest x neighbors : [x_2, x_1, x1, x2] of x_original + // 4 nearest y neighbors : [y_2, y_1, y1, y2] of y_original + // Sum_x is over 4 nearest x neighbors (similar for Sum_y) + // f(x_original, y_original) = Sum_y Sum_x W(x_original - x)*input[x,y] + // * W(y_original - y) + Value fxy = zero; + + for (int j = 0; j < 4; j++) { + Value wy = wys[j]; + Value xInterpy = zero; + + indices[dimOffset] = y[j]; + + for (int i = 0; i < 4; i++) { + Value wx = wxs[i]; + + indices[dimOffset + 1] = x[i]; + + Value p = b.create(loc, input, indices); + + Value wxp = b.create(loc, wx, p); + xInterpy = b.create(loc, xInterpy, wxp); + } + Value wyXInterpy = b.create(loc, wy, xInterpy); + fxy = b.create(loc, fxy, wyXInterpy); + } + + return fxy; +} + namespace { class ConvertInterpolateOp : public OpConversionPattern { @@ -2941,7 +3140,8 @@ public: // coordinate_transformation_mode="asymmetric" will lower to an interpolate // op with the non-standard mode="bilinear_asymmetric". matchPattern(op.getMode(), m_TorchConstantStr(mode)); - if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") { + if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest" && + mode.substr(0, 5) != "cubic") { return failure(); } @@ -3023,13 +3223,18 @@ public: (mode.find(",") == std::string::npos) ? "" : mode.substr(mode.find(",") + 1); - retVal = NearestInterpolate( + retVal = nearestInterpolate( b, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, coordTfMode, nearestMode); } else if (mode.substr(0, 8) == "bilinear") { - retVal = BilinearInterpolate( + retVal = bilinearInterpolate( b, op, loc, outputSizeIntValues, input, inputSizes, ScaleFactorFloatValues, mode.substr(8)); + } else if (mode.substr(0, 5) == "cubic") { + + retVal = bicubicInterpolate( + b, op, loc, outputSizeIntValues, input, inputSizes, + ScaleFactorFloatValues, mode.substr(5)); } b.create(loc, retVal); }) diff --git a/test/Conversion/TorchToLinalg/resize.mlir b/test/Conversion/TorchToLinalg/resize.mlir index 7976b1ad8..1dfe45492 100644 --- a/test/Conversion/TorchToLinalg/resize.mlir +++ b/test/Conversion/TorchToLinalg/resize.mlir @@ -21,14 +21,14 @@ func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: // CHECK-DAG: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32 // CHECK-DAG: %[[x26:.*]] = arith.subf %[[x25]], %[[cst_4]] : f32 // CHECK-DAG: %[[x27:.*]] = arith.maximumf %[[x26]], %[[cst_5]] : f32 - // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %[[cst]] : f32 + // CHECK-DAG: %[[x28:.*]] = arith.subf %[[x19]], %cst_4 : f32 // CHECK-DAG: %[[x29:.*]] = arith.minimumf %[[x27]], %[[x28]] : f32 // CHECK-DAG: %[[x30:.*]] = math.floor %[[x29]] : f32 // CHECK-DAG: %[[x31:.*]] = arith.addf %[[cst]], %[[x29]] : f32 // CHECK-DAG: %[[x32:.*]] = math.floor %[[x31]] : f32 // CHECK-DAG: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64 // CHECK-DAG: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index - // CHECK-DAG: %[[x35:.*]] = arith.minimumf %[[x31]], %[[x28]] : f32 + // CHECK-DAG: %[[x35:.*]] = arith.minimumf %44, %42 : f32 // CHECK-DAG: %[[x36:.*]] = arith.fptosi %[[x35]] : f32 to i64 // CHECK-DAG: %[[x37:.*]] = arith.index_cast %[[x36]] : i64 to index // CHECK: %[[extracted:.*]] = tensor.extract %[[x0]][%[[x15]], %[[x16]], %[[x34]], %[[low:.*]]] : tensor<1x1x2x4xf32> @@ -304,4 +304,51 @@ func.func @test_resize_nearest_half_pixel_round_prefer_floor(%arg0: !torch.vtens return %5 : !torch.vtensor<[?,?,?],f32> } +// CHECK-LABEL: func.func @test_resize_sizes_cubic +func.func @test_resize_sizes_cubic(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4] +,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 +: si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[x1:.*]] = math.ceil %36 : f32 + // CHECK-DAG: %[[x_1:.*]] = arith.subf %[[x1]], %cst_5 : f32 + // CHECK-DAG: %[[x_2:.*]] = arith.subf %[[x_1]], %cst_5 : f32 + // CHECK-DAG: %[[x2:.*]] = arith.addf %[[x1]], %cst_5 : f32 + // CHECK-DAG: %[[y1:.*]] = math.ceil %28 : f32 + // CHECK-DAG: %[[y_1:.*]] = arith.subf %[[y1]], %cst_5 : f32 + // CHECK-DAG: %[[y_2:.*]] = arith.subf %[[y_1]], %cst_5 : f32 + // CHECK-DAG: %[[y2:.*]] = arith.addf %[[y1]], %cst_5 : f32 + // CHECK-DAG: %[[y2D:.*]] = arith.subf %28, %[[y2]] : f32 + // CHECK-DAG: %[[y2Dist:.*]] = math.absf %[[y2D]] : f32 + // CHECK-DAG: %[[y1D:.*]] = arith.subf %28, %[[y1]] : f32 + // CHECK-DAG: %[[y1Dist:.*]] = math.absf %[[y1D]] : f32 + // CHECK-DAG: %[[y_1D:.*]] = arith.subf %28, %[[y_1]] : f32 + // CHECK-DAG: %[[y_1Dist:.*]] = math.absf %[[y_1D]] : f32 + // CHECK-DAG: %[[y_2D:.*]] = arith.subf %28, %[[y_2]] : f32 + // CHECK-DAG: %[[y_2Dist:.*]] = math.absf %[[y_2D]] : f32 + // CHECK-DAG: %[[x2D:.*]] = arith.subf %36, %[[x2]] : f32 + // CHECK-DAG: %[[x2Dist:.*]] = math.absf %[[x2D]] : f32 + // CHECK-DAG: %[[x1D:.*]] = arith.subf %36, %[[x1]] : f32 + // CHECK-DAG: %[[x1Dist:.*]] = math.absf %[[x1D]] : f32 + // CHECK-DAG: %[[x_1D:.*]] = arith.subf %36, %[[x_1]] : f32 + // CHECK-DAG: %[[x_1Dist:.*]] = math.absf %[[x_1D]] : f32 + // CHECK-DAG: %[[x_2D:.*]] = arith.subf %36, %[[x_2]] : f32 + // CHECK-DAG: %[[x_2Dist:.*]] = math.absf %[[x_2D]] : f32 + // CHECK-DAG: %[[distSQ:.*]] = arith.mulf %52, %52 : f32 + // CHECK-DAG: %[[distCubed:.*]] = arith.mulf %[[distSQ]], %52 : f32 + %none = torch.constant.none + %none_0 = torch.constant.none + %int0 = torch.constant.int 0 + %false = torch.constant.bool false + %true = torch.constant.bool true + %str = torch.constant.str "cubic" + %int2 = torch.constant.int 2 + %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + %int3 = torch.constant.int 3 + %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %5 : !torch.vtensor<[?,?,?,?],f32> +} + // -----