Modifies onnx resize lowering to fix numerical issues (#3381)

Updates:

- some unsupported configurations, including unsupported coordinate
transformation modes, now report a match failure.
- fixes a bug introduced in the last resize patch (my bad...)
- uses the actual x and y coordinates for computing weights in bilinear
interpolation (rather than eps-modified values)
- slightly simplifies the bilinear interpolation payload for readability
and performance
- passes the coordinate transformation mode from an onnx.Resize op through
the mode string of the aten._interpolate op (see the sketch after this
list). This allows us to perform custom logic in the torch->linalg lowering
to support onnx.Resize options without losing the default behaviors of the
interpolate op.
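
A rough sketch of that encoding for the common rank-4 case (illustrative
Python, not code from this patch; the rank-3 "linear" and rank-5 "trilinear"
variants handled in the lowering are omitted here):

# Hypothetical helper, for explanation only: how an onnx.Resize
# mode / coordinate_transformation_mode pair is folded into the single
# mode string handed to aten._interpolate.
def encode_interpolate_mode(mode: str, coord_tf_mode: str) -> str:
    # rank-4 "linear" becomes "bilinear"; "nearest" stays "nearest"
    base = "bilinear" if mode == "linear" else "nearest"
    # pytorch defaults: half_pixel for bilinear, asymmetric for nearest;
    # align_corners is carried separately as a boolean flag
    default = "half_pixel" if base == "bilinear" else "asymmetric"
    if coord_tf_mode in (default, "align_corners"):
        return base
    return base + "_" + coord_tf_mode

# e.g. mode="linear" with coordinate_transformation_mode="asymmetric"
# yields the non-standard mode string "bilinear_asymmetric"
assert encode_interpolate_mode("linear", "asymmetric") == "bilinear_asymmetric"
assert encode_interpolate_mode("nearest", "half_pixel") == "nearest_half_pixel"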
zjgarvey 2024-05-30 19:34:37 -05:00 committed by GitHub
parent d7b8f00d01
commit 074098d20c
6 changed files with 304 additions and 242 deletions


@@ -2784,12 +2784,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode", ""))
return failure();
if (coordTfMode == "tf_crop_and_resize")
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: coordinate transformation mode: "
"tf_crop_and_resize");
if (mode == "nearest" && nearest_mode != "floor") {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for nearest_mode "
"except floor");
}
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
.getSizes()
.size();
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
@@ -2851,36 +2857,54 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
Value sizesValueList = noneVal;
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
}
// supported modes:
// bilinear (half_pixel), bilinear with align_corners,
// bilinear_pytorch_half_pixel, bilinear_asymmetric nearest
// (asymmetric), nearest with align_corners, nearest_half_pixel,
// nearest_pytorch_half_pixel
if (mode == "linear") {
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
"bilinear");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
std::string modeStr;
switch (rank) {
case 3:
modeStr = "linear";
break;
case 4:
modeStr = "bilinear";
break;
case 5:
modeStr = "trilinear";
break;
default:
return failure();
}
// Confusingly enough, the default coordTfMode for pytorch bilinear
// mode is apparently half_pixel, NOT pytorch_half_pixel
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
modeStr = (modeStr + "_") + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
if (mode == "nearest") {
std::string modeStr = "nearest";
// The default coordTfMode for pytorch with mode = nearest is
// apparently asymmetric
if (coordTfMode != "asymmetric" && coordTfMode != "align_corners")
modeStr = (modeStr + "_") + coordTfMode;
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "nearest");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizesOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizesOperand);
}
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
}
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
if (scalesValueList.getType().isa<Torch::NoneType>() &&
sizesValueList.getType().isa<Torch::NoneType>()) {

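Both branches above select either the scales operand or the sizes operand and
leave the other value list as none; roughly (illustrative Python sketch, not
the patch itself):

# Hypothetical helper, for explanation only: onnx.Resize carries scales as
# operand 2 and sizes as operand 3; the lowering populates exactly one of
# the two value lists and leaves the other as none.
def pick_scales_or_sizes(operands):
    if len(operands) < 4:
        return operands[2], None   # scales present, sizes absent
    return None, operands[3]       # sizes present, scales ignored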

@@ -2671,7 +2671,9 @@ public:
static Value NearestInterpolate(OpBuilder &b, Location loc,
SmallVector<Value> outputSizes, Value input,
SmallVector<Value> inputSizes) {
SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
@@ -2692,7 +2694,11 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
// scale = length_resized / length_original
// x_original = x_resized / scale
Value scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputSizeFP);
Value scale;
if (scaleValues.empty())
scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputSizeFP);
else
scale = scaleValues[i - 2];
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIndex);
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
@@ -2715,167 +2721,139 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, SmallVector<Value> outputSizes,
Value input, SmallVector<Value> inputSizes) {
Value inputSizeH = inputSizes[0];
Value inputSizeW = inputSizes[1];
Value outputSizeH = outputSizes[0];
Value outputSizeW = outputSizes[1];
int hDimOffset = 2;
Value input, SmallVector<Value> inputSizes,
SmallVector<Value> scaleValues,
std::string coordStr) {
unsigned dimOffset = 2;
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
Value cstOneEps = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.001));
Value cstOneEps =
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.000001));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
Value yOut = b.create<linalg::IndexOp>(loc, 2);
Value xOut = b.create<linalg::IndexOp>(loc, 3);
bool alignCornersBool;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
Value yProj, xProj;
if (alignCornersBool) {
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
Value outputSizeHFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputHFP, cstOneFloat);
Value outputSizeHSubOne =
b.create<arith::SubFOp>(loc, outputSizeHFP, cstOneFloat);
Value hScale =
b.create<arith::DivFOp>(loc, inputHSubOne, outputSizeHSubOne);
Value yProjBeforeClamp = b.create<arith::MulFOp>(loc, yOutFP, hScale);
Value yMax = b.create<arith::MaximumFOp>(loc, yProjBeforeClamp, zero);
Value outputSizeHSubOneEps =
b.create<arith::SubFOp>(loc, outputSizeHFP, cstOneEps);
yProj = b.create<arith::MinimumFOp>(loc, outputSizeHSubOneEps, yMax);
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
Value outputSizeWFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputWFP, cstOneFloat);
Value outputSizeWSubOne =
b.create<arith::SubFOp>(loc, outputSizeWFP, cstOneFloat);
Value wScale =
b.create<arith::DivFOp>(loc, inputWSubOne, outputSizeWSubOne);
Value xProjBeforeClamp = b.create<arith::MulFOp>(loc, xOutFP, wScale);
Value xMax = b.create<arith::MaximumFOp>(loc, xProjBeforeClamp, zero);
Value outputSizeWSubOneEps =
b.create<arith::SubFOp>(loc, outputSizeWFP, cstOneEps);
xProj = b.create<arith::MinimumFOp>(loc, outputSizeWSubOneEps, xMax);
} else {
// y_original = (y_resized + 0.5) / scale - 0.5
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
Value outputSizeHFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
Value hScale = b.create<arith::DivFOp>(loc, outputSizeHFP, inputHFP);
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
Value yPlusHalf = b.create<arith::AddFOp>(loc, yOutFP, cstHalf);
Value yDivScale = b.create<arith::DivFOp>(loc, yPlusHalf, hScale);
Value ySubHalf = b.create<arith::SubFOp>(loc, yDivScale, cstHalf);
Value yMax = b.create<arith::MaximumFOp>(loc, ySubHalf, zero);
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputHFP, cstOneEps);
yProj = b.create<arith::MinimumFOp>(loc, yMax, inputHSubOne);
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
Value outputSizeWFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
Value wScale = b.create<arith::DivFOp>(loc, outputSizeWFP, inputWFP);
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
Value xPlusHalf = b.create<arith::AddFOp>(loc, xOutFP, cstHalf);
Value xDivScale = b.create<arith::DivFOp>(loc, xPlusHalf, wScale);
Value xSubHalf = b.create<arith::SubFOp>(loc, xDivScale, cstHalf);
// clamp
Value xMax = b.create<arith::MaximumFOp>(loc, xSubHalf, zero);
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputWFP, cstOneEps);
xProj = b.create<arith::MinimumFOp>(loc, xMax, inputWSubOne);
}
Value yLow = b.create<math::FloorOp>(loc, yProj);
Value yProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, yProj);
Value yHigh = b.create<math::FloorOp>(loc, yProjPlusOne);
Value xLow = b.create<math::FloorOp>(loc, xProj);
Value xProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, xProj);
Value xHigh = b.create<math::FloorOp>(loc, xProjPlusOne);
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
Value yLowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yLow);
Value yLowIdx = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yLowInt);
Value xLowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xLow);
Value xLowIdx = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xLowInt);
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
// length_original
Value inputFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
// length_resized
Value outputSizeFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizes[i]);
// scale = length_resized/length_original
Value scale;
if (alignCornersBool) {
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
Value outputSizeSubOne =
b.create<arith::SubFOp>(loc, outputSizeFP, cstOneFloat);
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
outputSizeSubOne, zero);
scale = b.create<arith::DivFOp>(loc, inputSubOne, outputSizeSubOne);
scale = b.create<arith::SelectOp>(loc, cmp, zero, scale);
coordStr = "_align_corners";
} else if (scaleValues.empty())
scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputFP);
else
scale = scaleValues[i];
// y_resized
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(),
indices[i + dimOffset]);
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
Value preClip;
if (coordStr == "_align_corners") {
preClip = b.create<arith::MulFOp>(loc, outFP, scale);
}
if (coordStr == "_asymmetric") {
preClip = b.create<arith::DivFOp>(loc, outFP, scale);
}
if (coordStr == "_pytorch_half_pixel" || coordStr == "") {
// half-pixel modes
// y_resized + 0.5
Value outPlusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
// (y_resized + 0.5) / scale
Value outDivScale = b.create<arith::DivFOp>(loc, outPlusHalf, scale);
// _ - 0.5
preClip = b.create<arith::SubFOp>(loc, outDivScale, cstHalf);
}
// for pytorch half pixel , special case for length_resized == 1:
if (coordStr == "_pytorch_half_pixel") {
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
outputSizeFP, cstOneFloat);
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
}
// clip to 0,inf
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
// length_original - 1.001
Value inputSubOneEps = b.create<arith::SubFOp>(loc, inputFP, cstOneEps);
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
// clip to [0,length_original - 1.001]
projEps.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOneEps));
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
Value yHighInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yHigh);
Value yHighIdx =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), yHighInt);
lowFP.push_back(b.create<math::FloorOp>(loc, projEps[i]));
Value projPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, projEps[i]);
highFP.push_back(b.create<math::FloorOp>(loc, projPlusOne));
Value xHighInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xHigh);
Value xHighIdx =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), xHighInt);
Value lowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowFP[i]);
low.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowInt));
indices[hDimOffset] = yLowIdx;
indices[hDimOffset + 1] = xLowIdx;
Value highInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), highFP[i]);
high.push_back(
b.create<arith::IndexCastOp>(loc, b.getIndexType(), highInt));
}
SmallVector<Value> cornerValues;
indices[dimOffset] = low[0];
indices[dimOffset + 1] = low[1];
Value p00 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[hDimOffset] = yLowIdx;
indices[hDimOffset + 1] = xHighIdx;
indices[dimOffset] = low[0];
indices[dimOffset + 1] = high[1];
Value p01 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[hDimOffset] = yHighIdx;
indices[hDimOffset + 1] = xLowIdx;
indices[dimOffset] = high[0];
indices[dimOffset + 1] = low[1];
Value p10 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[hDimOffset] = yHighIdx;
indices[hDimOffset + 1] = xHighIdx;
indices[dimOffset] = high[0];
indices[dimOffset + 1] = high[1];
Value p11 = b.create<tensor::ExtractOp>(loc, input, indices);
// p00 p01
// p10 p11
// (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) /
// (xhigh - xlow) * p01
Value xHighMinusxProj = b.create<arith::SubFOp>(loc, xHigh, xProj);
Value xHighMinusxLow = b.create<arith::SubFOp>(loc, xHigh, xLow);
Value w0 = b.create<arith::DivFOp>(loc, xHighMinusxProj, xHighMinusxLow);
Value lhs = b.create<arith::MulFOp>(loc, w0, p00);
// Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)),
// where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc.
// We interpolate via the weighted average of pij by weights Aij
// the formula is retval = Sum(pij*Aij for i and j in range(2))
// Note: we do not need to divide by total rect area == 1
Value xProjMinusxLow = b.create<arith::SubFOp>(loc, xProj, xLow);
Value w1 = b.create<arith::DivFOp>(loc, xProjMinusxLow, xHighMinusxLow);
Value rhs = b.create<arith::MulFOp>(loc, w1, p01);
// lengths : Aij == dyi*dxj
Value dy0 = b.create<arith::SubFOp>(loc, highFP[0], proj[0]);
Value dy1 = b.create<arith::SubFOp>(loc, proj[0], lowFP[0]);
Value dx0 = b.create<arith::SubFOp>(loc, highFP[1], proj[1]);
Value dx1 = b.create<arith::SubFOp>(loc, proj[1], lowFP[1]);
Value xInter = b.create<arith::AddFOp>(loc, lhs, rhs);
// left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01)
Value dx0p00 = b.create<arith::MulFOp>(loc, dx0, p00);
Value dx1p01 = b.create<arith::MulFOp>(loc, dx1, p01);
Value sum = b.create<arith::AddFOp>(loc, dx0p00, dx1p01);
Value left = b.create<arith::MulFOp>(loc, dy0, sum);
// right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11)
Value dx0p10 = b.create<arith::MulFOp>(loc, dx0, p10);
Value dx1p11 = b.create<arith::MulFOp>(loc, dx1, p11);
sum = b.create<arith::AddFOp>(loc, dx0p10, dx1p11);
Value right = b.create<arith::MulFOp>(loc, dy1, sum);
// (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) /
// (xhigh - xlow) * p11
lhs = b.create<arith::MulFOp>(loc, w0, p10);
rhs = b.create<arith::MulFOp>(loc, w1, p11);
Value xInter1 = b.create<arith::AddFOp>(loc, lhs, rhs);
// (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow)
// / (yhigh - ylow) * xInter1
Value yHighMinusyProj = b.create<arith::SubFOp>(loc, yHigh, yProj);
Value yHighMinusyLow = b.create<arith::SubFOp>(loc, yHigh, yLow);
w0 = b.create<arith::DivFOp>(loc, yHighMinusyProj, yHighMinusyLow);
lhs = b.create<arith::MulFOp>(loc, w0, xInter);
Value yProjMinusyLow = b.create<arith::SubFOp>(loc, yProj, yLow);
w1 = b.create<arith::DivFOp>(loc, yProjMinusyLow, yHighMinusyLow);
rhs = b.create<arith::MulFOp>(loc, w1, xInter1);
Value retVal = b.create<arith::AddFOp>(loc, lhs, rhs);
return retVal;
return b.create<arith::AddFOp>(loc, left, right);
}
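
Putting the rewritten payload together: for the default half_pixel case each
output pixel is the area-weighted average of its four neighbours, with the
corner indices taken from the eps-clamped projection but the weights computed
from the actual projection. An illustrative Python sketch (not the patch
itself, and with the scale derived from the sizes rather than passed in):

import math

def bilinear_sample(img, y_out, x_out, out_h, out_w, eps=1e-6):
    # img is a 2-D list of floats; this mirrors one iteration of the
    # linalg.generic body for the half_pixel coordinate transformation.
    in_h, in_w = len(img), len(img[0])

    def project(out_idx, in_len, out_len):
        scale = out_len / in_len                    # the patch may use explicit scale values instead
        x = (out_idx + 0.5) / scale - 0.5           # half_pixel transform
        x = max(x, 0.0)
        proj = min(x, in_len - 1.0)                 # used for the weights
        proj_eps = min(x, in_len - 1.0 - eps)       # used to pick corner indices
        low = math.floor(proj_eps)
        high = math.floor(proj_eps + 1.0)
        return proj, low, high

    y, y0, y1 = project(y_out, in_h, out_h)
    x, x0, x1 = project(x_out, in_w, out_w)
    dy0, dy1 = y1 - y, y - y0                       # dy0 + dy1 == 1
    dx0, dx1 = x1 - x, x - x0                       # dx0 + dx1 == 1
    # weighted sum of the four corners; the total area is 1, so no division needed
    return dy0 * (dx0 * img[y0][x0] + dx1 * img[y0][x1]) + \
           dy1 * (dx0 * img[y1][x0] + dx1 * img[y1][x1])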
namespace {
@@ -2888,8 +2866,12 @@ public:
ConversionPatternRewriter &rewriter) const override {
std::string mode;
// note: to support onnx.Resize, we are passing some extra options through
// the mode attribute. For example, onnx.Resize with mode="linear" and
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
// op with the non-standard mode="bilinear_asymmetric".
matchPattern(op.getMode(), m_TorchConstantStr(mode));
if (mode != "bilinear" && mode != "nearest") {
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
return failure();
}
@@ -2897,41 +2879,46 @@ public:
Value input = adaptor.getInput();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
if (mode == "bilinear" && inputRank != 4)
if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
return rewriter.notifyMatchFailure(
op,
"cannot perform bilinear interpolation when input spatial dims != 2");
SmallVector<Value> outputSizeIntValues;
SmallVector<Value> inputSizes;
SmallVector<Value> ScaleFactorFloatValues;
for (unsigned i = 2; i < inputRank; i++) {
Value inputSize = getDimOp(rewriter, loc, input, 2);
Value inputSize = getDimOp(rewriter, loc, input, i);
inputSizes.push_back(rewriter.create<arith::IndexCastOp>(
loc, rewriter.getIntegerType(64), inputSize));
}
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
bool recompScale;
if (!matchPattern(op.getRecomputeScaleFactor(),
m_TorchConstantBool(&recompScale)))
recompScale = false;
SmallVector<Value> ScaleFactorTorchFloat;
if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat))
return rewriter.notifyMatchFailure(
op, "unimplemented: the output_size is not constructed from "
"ListConstruct");
SmallVector<Value> ScaleFactorFloatValues;
ScaleFactorFloatValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
for (unsigned i = 0; i < inputRank - 2; i++) {
Value inputSizeFP = rewriter.create<arith::SIToFPOp>(
loc, rewriter.getF32Type(), inputSizes[i]);
Value scale = rewriter.create<arith::TruncFOp>(
ScaleFactorFloatValues[i] = rewriter.create<arith::TruncFOp>(
loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]);
Value outputSize =
rewriter.create<arith::MulFOp>(loc, inputSizeFP, scale);
Value outputSize = rewriter.create<arith::MulFOp>(
loc, inputSizeFP, ScaleFactorFloatValues[i]);
outputSize = rewriter.create<math::FloorOp>(loc, outputSize);
outputSize = rewriter.create<arith::FPToSIOp>(
loc, rewriter.getI64Type(), outputSize);
outputSizeIntValues.push_back(outputSize);
}
if (recompScale)
ScaleFactorFloatValues.clear();
} else {
SmallVector<Value> outputSizeTorchInt;
if (!getListConstructElements(op.getSize(), outputSizeTorchInt))
@@ -2948,12 +2935,9 @@ public:
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(dims), inputType.getElementType());
AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value finalRes =
rewriter
.create<linalg::GenericOp>(
@@ -2962,12 +2946,14 @@ public:
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value retVal;
if (mode == "nearest") {
retVal = NearestInterpolate(b, loc, outputSizeIntValues,
input, inputSizes);
} else if (mode == "bilinear") {
if (mode.substr(0, 7) == "nearest") {
retVal = NearestInterpolate(
b, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(7));
} else if (mode.substr(0, 8) == "bilinear") {
retVal = BilinearInterpolate(
b, op, loc, outputSizeIntValues, input, inputSizes);
b, op, loc, outputSizeIntValues, input, inputSizes,
ScaleFactorFloatValues, mode.substr(8));
}
b.create<linalg::YieldOp>(loc, retVal);
})

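In the scale_factor branch above, the output sizes are computed as
floor(input_size * scale), and when recompute_scale_factor is true the
explicit scales are cleared so the payload falls back to
scale = length_resized / length_original. Roughly (illustrative Python, not
the patch itself):

import math

# Hypothetical helper, for explanation only.
def resolve_output_sizes(input_sizes, scale_factors, recompute_scale_factor=False):
    output_sizes = [math.floor(size * scale)
                    for size, scale in zip(input_sizes, scale_factors)]
    if recompute_scale_factor:
        # drop the explicit scales; the interpolation payload will then
        # recompute scale = length_resized / length_original per dimension
        scale_factors = []
    return output_sizes, scale_factors
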

@@ -2417,7 +2417,7 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
//===----------------------------------------------------------------------===//
OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
StringAttr item = dyn_cast_or_null<StringAttr>(adaptor.getItem());
if (!item)
return nullptr;


@@ -22,6 +22,12 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"SplitWithSizes_Module_basic",
# lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
# these interpolate tests are added specifically to test onnx.Resize.
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
}
LINALG_CRASHING_SET = {
@@ -3089,6 +3095,10 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
"IndexTensorSelectDimModule_basic",
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
"IntFloatModule_basic",
"IntImplicitModule_basic",
"IsFloatingPointFloat_True",
@@ -3933,6 +3943,10 @@ ONNX_TOSA_XFAIL_SET = {
"IndexTensorStaticContiguousWithNoneModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorStaticNonContiguousWithNoneModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
"InterpolateDynamicModule_scales_recompute_bilinear",
"IntFloatModule_basic",
"IntImplicitModule_basic",
"IouOfModule_basic",


@@ -1367,3 +1367,97 @@ class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
)
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))
class InterpolateModule(torch.nn.Module):
def __init__(
self,
size=None,
scale_factor=None,
mode="nearest",
align_corners=None,
recompute_scale_factor=None,
antialias=False,
):
self.size = size
self.scale_factor = scale_factor
self.mode = mode
self.align_corners = align_corners
self.recompute_scale_factor = recompute_scale_factor
self.antialias = antialias
super().__init__()
def _forward(self, input):
return torch.nn.functional.interpolate(
input,
size=self.size,
scale_factor=self.scale_factor,
mode=self.mode,
align_corners=self.align_corners,
recompute_scale_factor=self.recompute_scale_factor,
antialias=self.antialias,
)
class InterpolateStaticModule(InterpolateModule):
@export
@annotate_args(
[
None,
([1, 1, 4, 5], torch.float32, True),
]
)
def forward(self, input):
return self._forward(input)
class InterpolateDynamicModule(InterpolateModule):
@export
@annotate_args(
[
None,
([-1, -1, -1, -1], torch.float32, True),
]
)
def forward(self, input):
return self._forward(input)
@register_test_case(
module_factory=lambda: InterpolateStaticModule(
scale_factor=0.41, mode="bilinear", align_corners=True
)
)
def InterpolateStaticModule_scales_bilinear_align_corners(module, tu: TestUtils):
input = torch.arange(20).to(dtype=torch.float32)
input = input.reshape((1, 1, 4, 5))
module.forward(input)
@register_test_case(
module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="nearest")
)
def InterpolateDynamicModule_sizes_nearest(module, tu: TestUtils):
input = torch.arange(20).to(dtype=torch.float32)
input = input.reshape((1, 1, 4, 5))
module.forward(input)
@register_test_case(
module_factory=lambda: InterpolateDynamicModule(size=(2, 7), mode="bilinear")
)
def InterpolateDynamicModule_sizes_bilinear(module, tu: TestUtils):
input = torch.arange(20).to(dtype=torch.float32)
input = input.reshape((1, 1, 4, 5))
module.forward(input)
@register_test_case(
module_factory=lambda: InterpolateDynamicModule(
scale_factor=(1.9, 2.4), mode="bilinear", recompute_scale_factor=True
)
)
def InterpolateDynamicModule_scales_recompute_bilinear(module, tu: TestUtils):
input = torch.arange(20).to(dtype=torch.float32)
input = input.reshape((1, 1, 4, 5))
module.forward(input)


@@ -4,75 +4,19 @@
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[generic:.*]] = linalg.generic
// CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
// CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[x13:.*]] = linalg.index 2 : index
// CHECK: %[[x14:.*]] = linalg.index 3 : index
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
// CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
// CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
// CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
// CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32
// CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32
// CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32
// CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32
// CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
// CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
// CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
// CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
// CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
// CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32
// CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32
// CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32
// CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32
// CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32
// CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32
// CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32
// CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32
// CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32
// CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32
// CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32
// CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32
// CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32
// CHECK: %[[x43:.*]] = linalg.index 0 : index
// CHECK: %[[x44:.*]] = linalg.index 1 : index
// CHECK: %[[x45:.*]] = linalg.index 2 : index
// CHECK: %[[x46:.*]] = linalg.index 3 : index
// CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64
// CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index
// CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64
// CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index
// CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64
// CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index
// CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64
// CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32>
// CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32
// CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32
// CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32
// CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %extracted : f32
// CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32
// CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32
// CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32
// CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32
// CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32
// CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32
// CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32
// CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32
// CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32
// CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32
// CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32
// CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32
// CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32
// CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32
// CHECK: %[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
// CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]]
// CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]]
// CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]
// CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]]
// CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]]
// CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]]
// CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]]
// CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]]
// CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]]
%none = torch.constant.none
%none_0 = torch.constant.none
%int0 = torch.constant.int 0