mirror of https://github.com/llvm/torch-mlir
Modifies onnx resize lowering to fix numerical issues (#3381)
Updates:
- some unsupported coordinate transformation modes now report a match failure.
- fixes a bug that was introduced in the last patch for resize (my bad...)
- uses the actual x and y coordinates for computing weights in bilinear interpolation (rather than eps-modified values)
- slightly simplifies the bilinear interpolation payload for readability and performance
- passes coordinate transformation mode information from an onnx.Resize op to the mode string for the aten._interpolate op. This allows us to perform custom logic in the torch->linalg lowering to support onnx.Resize options without losing the default behaviors of the interpolate op.
pull/3405/head
parent
d7b8f00d01
commit
074098d20c
|
@ -2784,12 +2784,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
|
||||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode", ""))
|
||||
return failure();
|
||||
|
||||
if (coordTfMode == "tf_crop_and_resize")
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: coordinate transformation mode: "
|
||||
"tf_crop_and_resize");
|
||||
if (mode == "nearest" && nearest_mode != "floor") {
|
||||
return rewriter.notifyMatchFailure(
|
||||
binder.op, "unimplemented: support not present for nearest_mode "
|
||||
"except floor");
|
||||
}
|
||||
unsigned rank = dyn_cast<Torch::ValueTensorType>(operands[0].getType())
|
||||
.getSizes()
|
||||
.size();
|
||||
|
||||
Value zero = rewriter.create<Torch::ConstantIntOp>(
|
||||
binder.getLoc(), rewriter.getType<Torch::IntType>(),
|
||||
|
@ -2851,36 +2857,54 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
|||
Value sizesValueList = noneVal;
|
||||
Value alignCorners =
|
||||
coordTfMode == "align_corners" ? cstTrue : cstFalse;
|
||||
|
||||
if (mode == "cubic") {
|
||||
return rewriter.notifyMatchFailure(binder.op,
|
||||
"unimplemented: bicubic mode");
|
||||
}
|
||||
// supported modes:
|
||||
// bilinear (half_pixel), bilinear with align_corners,
|
||||
// bilinear_pytorch_half_pixel, bilinear_asymmetric, nearest
|
||||
// (asymmetric), nearest with align_corners, nearest_half_pixel,
|
||||
// nearest_pytorch_half_pixel
|
||||
if (mode == "linear") {
|
||||
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
|
||||
"bilinear");
|
||||
if (operands.size() < 4) {
|
||||
Value scaleOperand = operands[2];
|
||||
scalesValueList = getValueList(scaleOperand);
|
||||
sizesValueList = noneVal;
|
||||
} else {
|
||||
Value sizeOperand = operands[3];
|
||||
scalesValueList = noneVal;
|
||||
sizesValueList = getValueList(sizeOperand);
|
||||
std::string modeStr;
|
||||
switch (rank) {
|
||||
case 3:
|
||||
modeStr = "linear";
|
||||
break;
|
||||
case 4:
|
||||
modeStr = "bilinear";
|
||||
break;
|
||||
case 5:
|
||||
modeStr = "trilinear";
|
||||
break;
|
||||
default:
|
||||
return failure();
|
||||
}
|
||||
// Confusingly enough, the default coordTfMode for pytorch bilinear
|
||||
// mode is apparently half_pixel, NOT pytorch_half_pixel
|
||||
if (coordTfMode != "half_pixel" && coordTfMode != "align_corners")
|
||||
modeStr = (modeStr + "_") + coordTfMode;
|
||||
modeStrValue =
|
||||
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
|
||||
}
|
||||
if (mode == "nearest") {
|
||||
std::string modeStr = "nearest";
|
||||
// The default coordTfMode for pytorch with mode = nearest is
|
||||
// apparently asymmetric
|
||||
if (coordTfMode != "asymmetric" && coordTfMode != "align_corners")
|
||||
modeStr = (modeStr + "_") + coordTfMode;
|
||||
modeStrValue =
|
||||
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "nearest");
|
||||
if (operands.size() < 4) {
|
||||
Value scaleOperand = operands[2];
|
||||
scalesValueList = getValueList(scaleOperand);
|
||||
sizesValueList = noneVal;
|
||||
} else {
|
||||
Value sizesOperand = operands[3];
|
||||
scalesValueList = noneVal;
|
||||
sizesValueList = getValueList(sizesOperand);
|
||||
}
|
||||
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), modeStr);
|
||||
}
|
||||
if (operands.size() < 4) {
|
||||
Value scaleOperand = operands[2];
|
||||
scalesValueList = getValueList(scaleOperand);
|
||||
sizesValueList = noneVal;
|
||||
} else {
|
||||
Value sizeOperand = operands[3];
|
||||
scalesValueList = noneVal;
|
||||
sizesValueList = getValueList(sizeOperand);
|
||||
}
|
||||
if (scalesValueList.getType().isa<Torch::NoneType>() &&
|
||||
sizesValueList.getType().isa<Torch::NoneType>()) {
|
||||
|
|
|
@ -2671,7 +2671,9 @@ public:
|
|||
|
||||
static Value NearestInterpolate(OpBuilder &b, Location loc,
|
||||
SmallVector<Value> outputSizes, Value input,
|
||||
SmallVector<Value> inputSizes) {
|
||||
SmallVector<Value> inputSizes,
|
||||
SmallVector<Value> scaleValues,
|
||||
std::string coordStr) {
|
||||
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputRank = inputType.getRank();
|
||||
|
@ -2692,7 +2694,11 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
|
|||
|
||||
// scale = length_resized / length_original
|
||||
// x_original = x_resized / scale
|
||||
Value scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputSizeFP);
|
||||
Value scale;
|
||||
if (scaleValues.empty())
|
||||
scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputSizeFP);
|
||||
else
|
||||
scale = scaleValues[i - 2];
|
||||
|
||||
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIndex);
|
||||
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
|
||||
|
@ -2715,167 +2721,139 @@ static Value NearestInterpolate(OpBuilder &b, Location loc,
|
|||
static Value BilinearInterpolate(OpBuilder &b,
|
||||
Aten__InterpolateSizeListScaleListOp op,
|
||||
Location loc, SmallVector<Value> outputSizes,
|
||||
Value input, SmallVector<Value> inputSizes) {
|
||||
Value inputSizeH = inputSizes[0];
|
||||
Value inputSizeW = inputSizes[1];
|
||||
Value outputSizeH = outputSizes[0];
|
||||
Value outputSizeW = outputSizes[1];
|
||||
|
||||
int hDimOffset = 2;
|
||||
Value input, SmallVector<Value> inputSizes,
|
||||
SmallVector<Value> scaleValues,
|
||||
std::string coordStr) {
|
||||
unsigned dimOffset = 2;
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputRank = inputType.getRank();
|
||||
|
||||
Value cstOneEps = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.001));
|
||||
Value cstOneEps =
|
||||
b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.000001));
|
||||
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
|
||||
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
|
||||
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
|
||||
|
||||
Value yOut = b.create<linalg::IndexOp>(loc, 2);
|
||||
Value xOut = b.create<linalg::IndexOp>(loc, 3);
|
||||
|
||||
bool alignCornersBool;
|
||||
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
|
||||
|
||||
Value yProj, xProj;
|
||||
if (alignCornersBool) {
|
||||
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
|
||||
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
|
||||
Value outputSizeHFP =
|
||||
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
|
||||
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
|
||||
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
|
||||
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputHFP, cstOneFloat);
|
||||
Value outputSizeHSubOne =
|
||||
b.create<arith::SubFOp>(loc, outputSizeHFP, cstOneFloat);
|
||||
Value hScale =
|
||||
b.create<arith::DivFOp>(loc, inputHSubOne, outputSizeHSubOne);
|
||||
Value yProjBeforeClamp = b.create<arith::MulFOp>(loc, yOutFP, hScale);
|
||||
Value yMax = b.create<arith::MaximumFOp>(loc, yProjBeforeClamp, zero);
|
||||
Value outputSizeHSubOneEps =
|
||||
b.create<arith::SubFOp>(loc, outputSizeHFP, cstOneEps);
|
||||
yProj = b.create<arith::MinimumFOp>(loc, outputSizeHSubOneEps, yMax);
|
||||
|
||||
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
|
||||
Value outputSizeWFP =
|
||||
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
|
||||
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
|
||||
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
|
||||
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputWFP, cstOneFloat);
|
||||
Value outputSizeWSubOne =
|
||||
b.create<arith::SubFOp>(loc, outputSizeWFP, cstOneFloat);
|
||||
Value wScale =
|
||||
b.create<arith::DivFOp>(loc, inputWSubOne, outputSizeWSubOne);
|
||||
Value xProjBeforeClamp = b.create<arith::MulFOp>(loc, xOutFP, wScale);
|
||||
Value xMax = b.create<arith::MaximumFOp>(loc, xProjBeforeClamp, zero);
|
||||
Value outputSizeWSubOneEps =
|
||||
b.create<arith::SubFOp>(loc, outputSizeWFP, cstOneEps);
|
||||
xProj = b.create<arith::MinimumFOp>(loc, outputSizeWSubOneEps, xMax);
|
||||
} else {
|
||||
// y_original = (y_resized + 0.5) / scale - 0.5
|
||||
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
|
||||
Value outputSizeHFP =
|
||||
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
|
||||
Value hScale = b.create<arith::DivFOp>(loc, outputSizeHFP, inputHFP);
|
||||
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
|
||||
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
|
||||
Value yPlusHalf = b.create<arith::AddFOp>(loc, yOutFP, cstHalf);
|
||||
Value yDivScale = b.create<arith::DivFOp>(loc, yPlusHalf, hScale);
|
||||
Value ySubHalf = b.create<arith::SubFOp>(loc, yDivScale, cstHalf);
|
||||
Value yMax = b.create<arith::MaximumFOp>(loc, ySubHalf, zero);
|
||||
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputHFP, cstOneEps);
|
||||
yProj = b.create<arith::MinimumFOp>(loc, yMax, inputHSubOne);
|
||||
|
||||
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
|
||||
Value outputSizeWFP =
|
||||
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
|
||||
Value wScale = b.create<arith::DivFOp>(loc, outputSizeWFP, inputWFP);
|
||||
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
|
||||
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
|
||||
Value xPlusHalf = b.create<arith::AddFOp>(loc, xOutFP, cstHalf);
|
||||
Value xDivScale = b.create<arith::DivFOp>(loc, xPlusHalf, wScale);
|
||||
Value xSubHalf = b.create<arith::SubFOp>(loc, xDivScale, cstHalf);
|
||||
// clamp
|
||||
Value xMax = b.create<arith::MaximumFOp>(loc, xSubHalf, zero);
|
||||
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputWFP, cstOneEps);
|
||||
xProj = b.create<arith::MinimumFOp>(loc, xMax, inputWSubOne);
|
||||
}
|
||||
Value yLow = b.create<math::FloorOp>(loc, yProj);
|
||||
Value yProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, yProj);
|
||||
Value yHigh = b.create<math::FloorOp>(loc, yProjPlusOne);
|
||||
|
||||
Value xLow = b.create<math::FloorOp>(loc, xProj);
|
||||
Value xProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, xProj);
|
||||
Value xHigh = b.create<math::FloorOp>(loc, xProjPlusOne);
|
||||
|
||||
SmallVector<Value> indices;
|
||||
for (unsigned i = 0; i < inputRank; i++) {
|
||||
indices.push_back(b.create<linalg::IndexOp>(loc, i));
|
||||
}
|
||||
Value yLowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yLow);
|
||||
Value yLowIdx = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yLowInt);
|
||||
|
||||
Value xLowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xLow);
|
||||
Value xLowIdx = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xLowInt);
|
||||
SmallVector<Value> proj, projEps, high, low, highFP, lowFP;
|
||||
for (unsigned i = 0; i < inputRank - dimOffset; i++) {
|
||||
// length_original
|
||||
Value inputFP =
|
||||
b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizes[i]);
|
||||
// length_resized
|
||||
Value outputSizeFP =
|
||||
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizes[i]);
|
||||
// scale = length_resized/length_original
|
||||
Value scale;
|
||||
if (alignCornersBool) {
|
||||
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
|
||||
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
|
||||
Value outputSizeSubOne =
|
||||
b.create<arith::SubFOp>(loc, outputSizeFP, cstOneFloat);
|
||||
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
|
||||
outputSizeSubOne, zero);
|
||||
scale = b.create<arith::DivFOp>(loc, inputSubOne, outputSizeSubOne);
|
||||
scale = b.create<arith::SelectOp>(loc, cmp, zero, scale);
|
||||
coordStr = "_align_corners";
|
||||
} else if (scaleValues.empty())
|
||||
scale = b.create<arith::DivFOp>(loc, outputSizeFP, inputFP);
|
||||
else
|
||||
scale = scaleValues[i];
|
||||
// y_resized
|
||||
Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(),
|
||||
indices[i + dimOffset]);
|
||||
Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);
|
||||
Value preClip;
|
||||
if (coordStr == "_align_corners") {
|
||||
preClip = b.create<arith::MulFOp>(loc, outFP, scale);
|
||||
}
|
||||
if (coordStr == "_asymmetric") {
|
||||
preClip = b.create<arith::DivFOp>(loc, outFP, scale);
|
||||
}
|
||||
if (coordStr == "_pytorch_half_pixel" || coordStr == "") {
|
||||
// half-pixel modes
|
||||
// y_resized + 0.5
|
||||
Value outPlusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
|
||||
// (y_resized + 0.5) / scale
|
||||
Value outDivScale = b.create<arith::DivFOp>(loc, outPlusHalf, scale);
|
||||
// _ - 0.5
|
||||
preClip = b.create<arith::SubFOp>(loc, outDivScale, cstHalf);
|
||||
}
|
||||
// for pytorch_half_pixel, special case for length_resized == 1:
|
||||
if (coordStr == "_pytorch_half_pixel") {
|
||||
Value cmp = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::UEQ,
|
||||
outputSizeFP, cstOneFloat);
|
||||
preClip = b.create<arith::SelectOp>(loc, cmp, zero, preClip);
|
||||
}
|
||||
// clip to 0,inf
|
||||
Value max = b.create<arith::MaximumFOp>(loc, preClip, zero);
|
||||
// length_original - 1.000001 (i.e. length_original - cstOneEps)
|
||||
Value inputSubOneEps = b.create<arith::SubFOp>(loc, inputFP, cstOneEps);
|
||||
Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
|
||||
// clip to [0, length_original - 1.000001] (projEps) and to [0, length_original - 1] (proj)
|
||||
projEps.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOneEps));
|
||||
proj.push_back(b.create<arith::MinimumFOp>(loc, max, inputSubOne));
|
||||
|
||||
Value yHighInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yHigh);
|
||||
Value yHighIdx =
|
||||
b.create<arith::IndexCastOp>(loc, b.getIndexType(), yHighInt);
|
||||
lowFP.push_back(b.create<math::FloorOp>(loc, projEps[i]));
|
||||
Value projPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, projEps[i]);
|
||||
highFP.push_back(b.create<math::FloorOp>(loc, projPlusOne));
|
||||
|
||||
Value xHighInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xHigh);
|
||||
Value xHighIdx =
|
||||
b.create<arith::IndexCastOp>(loc, b.getIndexType(), xHighInt);
|
||||
Value lowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), lowFP[i]);
|
||||
low.push_back(b.create<arith::IndexCastOp>(loc, b.getIndexType(), lowInt));
|
||||
|
||||
indices[hDimOffset] = yLowIdx;
|
||||
indices[hDimOffset + 1] = xLowIdx;
|
||||
Value highInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), highFP[i]);
|
||||
high.push_back(
|
||||
b.create<arith::IndexCastOp>(loc, b.getIndexType(), highInt));
|
||||
}
|
||||
|
||||
SmallVector<Value> cornerValues;
|
||||
indices[dimOffset] = low[0];
|
||||
indices[dimOffset + 1] = low[1];
|
||||
Value p00 = b.create<tensor::ExtractOp>(loc, input, indices);
|
||||
|
||||
indices[hDimOffset] = yLowIdx;
|
||||
indices[hDimOffset + 1] = xHighIdx;
|
||||
indices[dimOffset] = low[0];
|
||||
indices[dimOffset + 1] = high[1];
|
||||
Value p01 = b.create<tensor::ExtractOp>(loc, input, indices);
|
||||
|
||||
indices[hDimOffset] = yHighIdx;
|
||||
indices[hDimOffset + 1] = xLowIdx;
|
||||
indices[dimOffset] = high[0];
|
||||
indices[dimOffset + 1] = low[1];
|
||||
Value p10 = b.create<tensor::ExtractOp>(loc, input, indices);
|
||||
|
||||
indices[hDimOffset] = yHighIdx;
|
||||
indices[hDimOffset + 1] = xHighIdx;
|
||||
indices[dimOffset] = high[0];
|
||||
indices[dimOffset + 1] = high[1];
|
||||
Value p11 = b.create<tensor::ExtractOp>(loc, input, indices);
|
||||
|
||||
// p00 p01
|
||||
// p10 p11
|
||||
// (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) /
|
||||
// (xhigh - xlow) * p01
|
||||
Value xHighMinusxProj = b.create<arith::SubFOp>(loc, xHigh, xProj);
|
||||
Value xHighMinusxLow = b.create<arith::SubFOp>(loc, xHigh, xLow);
|
||||
Value w0 = b.create<arith::DivFOp>(loc, xHighMinusxProj, xHighMinusxLow);
|
||||
Value lhs = b.create<arith::MulFOp>(loc, w0, p00);
|
||||
// Let Aij := area rect((yProj,xProj) <-> (y_i*,x_j*)),
|
||||
// where i* = i+1 mod 2 and x_0 = xLow, x_1 = xHigh etc.
|
||||
// We interpolate via the weighted average of pij by weights Aij
|
||||
// the formula is retval = Sum(pij*Aij for i and j in range(2))
|
||||
// Note: we do not need to divide by total rect area == 1
|
||||
|
||||
Value xProjMinusxLow = b.create<arith::SubFOp>(loc, xProj, xLow);
|
||||
Value w1 = b.create<arith::DivFOp>(loc, xProjMinusxLow, xHighMinusxLow);
|
||||
Value rhs = b.create<arith::MulFOp>(loc, w1, p01);
|
||||
// lengths : Aij == dyi*dxj
|
||||
Value dy0 = b.create<arith::SubFOp>(loc, highFP[0], proj[0]);
|
||||
Value dy1 = b.create<arith::SubFOp>(loc, proj[0], lowFP[0]);
|
||||
Value dx0 = b.create<arith::SubFOp>(loc, highFP[1], proj[1]);
|
||||
Value dx1 = b.create<arith::SubFOp>(loc, proj[1], lowFP[1]);
|
||||
|
||||
Value xInter = b.create<arith::AddFOp>(loc, lhs, rhs);
|
||||
// left = A00*p00 + A01*p01 = dy0(dx0p00 + dx1p01)
|
||||
Value dx0p00 = b.create<arith::MulFOp>(loc, dx0, p00);
|
||||
Value dx1p01 = b.create<arith::MulFOp>(loc, dx1, p01);
|
||||
Value sum = b.create<arith::AddFOp>(loc, dx0p00, dx1p01);
|
||||
Value left = b.create<arith::MulFOp>(loc, dy0, sum);
|
||||
// right = A10*p10 + A11*p11 = dy1(dx0p10 + dx1p11)
|
||||
Value dx0p10 = b.create<arith::MulFOp>(loc, dx0, p10);
|
||||
Value dx1p11 = b.create<arith::MulFOp>(loc, dx1, p11);
|
||||
sum = b.create<arith::AddFOp>(loc, dx0p10, dx1p11);
|
||||
Value right = b.create<arith::MulFOp>(loc, dy1, sum);
|
||||
|
||||
// (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) /
|
||||
// (xhigh - xlow) * p11
|
||||
lhs = b.create<arith::MulFOp>(loc, w0, p10);
|
||||
rhs = b.create<arith::MulFOp>(loc, w1, p11);
|
||||
|
||||
Value xInter1 = b.create<arith::AddFOp>(loc, lhs, rhs);
|
||||
|
||||
// (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow)
|
||||
// / (yhigh - ylow) * xInter1
|
||||
Value yHighMinusyProj = b.create<arith::SubFOp>(loc, yHigh, yProj);
|
||||
Value yHighMinusyLow = b.create<arith::SubFOp>(loc, yHigh, yLow);
|
||||
w0 = b.create<arith::DivFOp>(loc, yHighMinusyProj, yHighMinusyLow);
|
||||
lhs = b.create<arith::MulFOp>(loc, w0, xInter);
|
||||
|
||||
Value yProjMinusyLow = b.create<arith::SubFOp>(loc, yProj, yLow);
|
||||
w1 = b.create<arith::DivFOp>(loc, yProjMinusyLow, yHighMinusyLow);
|
||||
rhs = b.create<arith::MulFOp>(loc, w1, xInter1);
|
||||
|
||||
Value retVal = b.create<arith::AddFOp>(loc, lhs, rhs);
|
||||
return retVal;
|
||||
return b.create<arith::AddFOp>(loc, left, right);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -2888,8 +2866,12 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
std::string mode;
|
||||
// note: to support onnx.Resize, we are passing some extra options through
|
||||
// the mode attribute. For example, onnx.Resize with mode="linear" and
|
||||
// coordinate_transformation_mode="asymmetric" will lower to an interpolate
|
||||
// op with the non-standard mode="bilinear_asymmetric".
|
||||
matchPattern(op.getMode(), m_TorchConstantStr(mode));
|
||||
if (mode != "bilinear" && mode != "nearest") {
|
||||
if (mode.substr(0, 8) != "bilinear" && mode.substr(0, 7) != "nearest") {
|
||||
return failure();
|
||||
}
|
||||
|
||||
|
@ -2897,41 +2879,46 @@ public:
|
|||
Value input = adaptor.getInput();
|
||||
auto inputType = input.getType().cast<RankedTensorType>();
|
||||
auto inputRank = inputType.getRank();
|
||||
if (mode == "bilinear" && inputRank != 4)
|
||||
if (mode.substr(0, 8) == "bilinear" && inputRank != 4)
|
||||
return rewriter.notifyMatchFailure(
|
||||
op,
|
||||
"cannot perform bilinear interpolation when input spatial dims != 2");
|
||||
|
||||
SmallVector<Value> outputSizeIntValues;
|
||||
SmallVector<Value> inputSizes;
|
||||
SmallVector<Value> ScaleFactorFloatValues;
|
||||
for (unsigned i = 2; i < inputRank; i++) {
|
||||
Value inputSize = getDimOp(rewriter, loc, input, 2);
|
||||
Value inputSize = getDimOp(rewriter, loc, input, i);
|
||||
inputSizes.push_back(rewriter.create<arith::IndexCastOp>(
|
||||
loc, rewriter.getIntegerType(64), inputSize));
|
||||
}
|
||||
|
||||
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
|
||||
bool recompScale;
|
||||
if (!matchPattern(op.getRecomputeScaleFactor(),
|
||||
m_TorchConstantBool(&recompScale)))
|
||||
recompScale = false;
|
||||
SmallVector<Value> ScaleFactorTorchFloat;
|
||||
if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat))
|
||||
return rewriter.notifyMatchFailure(
|
||||
op, "unimplemented: the output_size is not constructed from "
|
||||
"ListConstruct");
|
||||
SmallVector<Value> ScaleFactorFloatValues;
|
||||
ScaleFactorFloatValues = getTypeConvertedValues(
|
||||
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
|
||||
for (unsigned i = 0; i < inputRank - 2; i++) {
|
||||
Value inputSizeFP = rewriter.create<arith::SIToFPOp>(
|
||||
loc, rewriter.getF32Type(), inputSizes[i]);
|
||||
Value scale = rewriter.create<arith::TruncFOp>(
|
||||
ScaleFactorFloatValues[i] = rewriter.create<arith::TruncFOp>(
|
||||
loc, inputSizeFP.getType(), ScaleFactorFloatValues[i]);
|
||||
Value outputSize =
|
||||
rewriter.create<arith::MulFOp>(loc, inputSizeFP, scale);
|
||||
Value outputSize = rewriter.create<arith::MulFOp>(
|
||||
loc, inputSizeFP, ScaleFactorFloatValues[i]);
|
||||
outputSize = rewriter.create<math::FloorOp>(loc, outputSize);
|
||||
outputSize = rewriter.create<arith::FPToSIOp>(
|
||||
loc, rewriter.getI64Type(), outputSize);
|
||||
|
||||
outputSizeIntValues.push_back(outputSize);
|
||||
}
|
||||
if (recompScale)
|
||||
ScaleFactorFloatValues.clear();
|
||||
} else {
|
||||
SmallVector<Value> outputSizeTorchInt;
|
||||
if (!getListConstructElements(op.getSize(), outputSizeTorchInt))
|
||||
|
@ -2948,12 +2935,9 @@ public:
|
|||
|
||||
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||
loc, getAsOpFoldResult(dims), inputType.getElementType());
|
||||
|
||||
AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
|
||||
|
||||
SmallVector<utils::IteratorType> iteratorTypes(
|
||||
inputRank, utils::IteratorType::parallel);
|
||||
|
||||
Value finalRes =
|
||||
rewriter
|
||||
.create<linalg::GenericOp>(
|
||||
|
@ -2962,12 +2946,14 @@ public:
|
|||
/*iteratorTypes=*/iteratorTypes,
|
||||
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||
Value retVal;
|
||||
if (mode == "nearest") {
|
||||
retVal = NearestInterpolate(b, loc, outputSizeIntValues,
|
||||
input, inputSizes);
|
||||
} else if (mode == "bilinear") {
|
||||
if (mode.substr(0, 7) == "nearest") {
|
||||
retVal = NearestInterpolate(
|
||||
b, loc, outputSizeIntValues, input, inputSizes,
|
||||
ScaleFactorFloatValues, mode.substr(7));
|
||||
} else if (mode.substr(0, 8) == "bilinear") {
|
||||
retVal = BilinearInterpolate(
|
||||
b, op, loc, outputSizeIntValues, input, inputSizes);
|
||||
b, op, loc, outputSizeIntValues, input, inputSizes,
|
||||
ScaleFactorFloatValues, mode.substr(8));
|
||||
}
|
||||
b.create<linalg::YieldOp>(loc, retVal);
|
||||
})
|
||||
|
|
|
@ -2417,7 +2417,7 @@ OpFoldResult AtenNeStrOp::fold(FoldAdaptor adaptor) {
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult Aten__Contains__StrListOp::fold(FoldAdaptor adaptor) {
|
||||
StringAttr item = dyn_cast<StringAttr>(adaptor.getItem());
|
||||
StringAttr item = dyn_cast_or_null<StringAttr>(adaptor.getItem());
|
||||
if (!item)
|
||||
return nullptr;
|
||||
|
||||
|
|
|
@ -22,6 +22,12 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | {
|
|||
"IscloseStaticModule_basic",
|
||||
"IscloseStaticModuleTrue_basic",
|
||||
"SplitWithSizes_Module_basic",
|
||||
# lowering to torch backend IR fails due to unsupported op: aten.upsample_[mode/dims].vec
|
||||
# these interpolate tests are added specifically to test onnx.Resize.
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
}
|
||||
|
||||
LINALG_CRASHING_SET = {
|
||||
|
@ -3089,6 +3095,10 @@ FX_IMPORTER_TOSA_XFAIL_SET = {
|
|||
"IndexTensorSelectDimModule_basic",
|
||||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
"IntFloatModule_basic",
|
||||
"IntImplicitModule_basic",
|
||||
"IsFloatingPointFloat_True",
|
||||
|
@ -3933,6 +3943,10 @@ ONNX_TOSA_XFAIL_SET = {
|
|||
"IndexTensorStaticContiguousWithNoneModule_basic",
|
||||
"IndexTensorStaticModule_basic",
|
||||
"IndexTensorStaticNonContiguousWithNoneModule_basic",
|
||||
"InterpolateDynamicModule_sizes_bilinear",
|
||||
"InterpolateDynamicModule_sizes_nearest",
|
||||
"InterpolateStaticModule_scales_bilinear_align_corners",
|
||||
"InterpolateDynamicModule_scales_recompute_bilinear",
|
||||
"IntFloatModule_basic",
|
||||
"IntImplicitModule_basic",
|
||||
"IouOfModule_basic",
|
||||
|
|
|
@ -1367,3 +1367,97 @@ class EinsumStaticWithEllipsisSlicingAndBroadcastModule(torch.nn.Module):
|
|||
)
|
||||
def EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(2, 6, 4, 5), tu.rand(6, 5))
|
||||
|
||||
|
||||
class InterpolateModule(torch.nn.Module):
    """Test module that forwards stored options to ``torch.nn.functional.interpolate``.

    The constructor arguments mirror the keyword arguments of
    ``torch.nn.functional.interpolate``; they are stored unchanged on the
    instance and passed through by ``_forward``.
    """

    def __init__(
        self,
        size=None,
        scale_factor=None,
        mode="nearest",
        align_corners=None,
        recompute_scale_factor=None,
        antialias=False,
    ):
        # Initialize the nn.Module base first so every attribute assignment
        # below happens on a fully constructed module (the conventional order
        # for nn.Module subclasses).
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor
        self.antialias = antialias

    def _forward(self, input):
        """Return ``torch.nn.functional.interpolate(input, ...)`` using the stored options."""
        return torch.nn.functional.interpolate(
            input,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            recompute_scale_factor=self.recompute_scale_factor,
            antialias=self.antialias,
        )
|
||||
|
||||
|
||||
class InterpolateStaticModule(InterpolateModule):
    """Interpolate test module whose ``forward`` input is annotated as a
    ``[1, 1, 4, 5]`` float32 tensor."""

    @export
    @annotate_args([None, ([1, 1, 4, 5], torch.float32, True)])
    def forward(self, input):
        # Delegate to the shared implementation on the base class.
        result = self._forward(input)
        return result
|
||||
|
||||
|
||||
class InterpolateDynamicModule(InterpolateModule):
    """Interpolate test module whose ``forward`` input is annotated as a
    rank-4 float32 tensor with every dimension given as ``-1``
    (presumably "dynamic size", per the class name -- confirm against
    ``annotate_args``)."""

    @export
    @annotate_args([None, ([-1, -1, -1, -1], torch.float32, True)])
    def forward(self, input):
        # Delegate to the shared implementation on the base class.
        result = self._forward(input)
        return result
|
||||
|
||||
|
||||
@register_test_case(
    module_factory=lambda: InterpolateStaticModule(
        scale_factor=0.41,
        mode="bilinear",
        align_corners=True,
    )
)
def InterpolateStaticModule_scales_bilinear_align_corners(module, tu: TestUtils):
    # Bilinear interpolation by scale factor with align_corners=True, fed a
    # fixed 1x1x4x5 float32 ramp holding the values 0..19.
    ramp = torch.arange(20, dtype=torch.float32).reshape(1, 1, 4, 5)
    module.forward(ramp)
|
||||
|
||||
|
||||
@register_test_case(
    module_factory=lambda: InterpolateDynamicModule(
        size=(2, 7),
        mode="nearest",
    )
)
def InterpolateDynamicModule_sizes_nearest(module, tu: TestUtils):
    # Nearest-mode interpolation to an explicit (2, 7) output size, fed a
    # fixed 1x1x4x5 float32 ramp holding the values 0..19.
    ramp = torch.arange(20, dtype=torch.float32).reshape(1, 1, 4, 5)
    module.forward(ramp)
|
||||
|
||||
|
||||
@register_test_case(
    module_factory=lambda: InterpolateDynamicModule(
        size=(2, 7),
        mode="bilinear",
    )
)
def InterpolateDynamicModule_sizes_bilinear(module, tu: TestUtils):
    # Bilinear interpolation to an explicit (2, 7) output size, fed a
    # fixed 1x1x4x5 float32 ramp holding the values 0..19.
    ramp = torch.arange(20, dtype=torch.float32).reshape(1, 1, 4, 5)
    module.forward(ramp)
|
||||
|
||||
|
||||
@register_test_case(
    module_factory=lambda: InterpolateDynamicModule(
        scale_factor=(1.9, 2.4),
        mode="bilinear",
        recompute_scale_factor=True,
    )
)
def InterpolateDynamicModule_scales_recompute_bilinear(module, tu: TestUtils):
    # Bilinear interpolation with per-axis scale factors and
    # recompute_scale_factor=True, fed a fixed 1x1x4x5 float32 ramp
    # holding the values 0..19.
    ramp = torch.arange(20, dtype=torch.float32).reshape(1, 1, 4, 5)
    module.forward(ramp)
|
||||
|
|
|
@ -4,75 +4,19 @@
|
|||
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4]
|
||||
,si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
|
||||
// CHECK: %[[generic:.*]] = linalg.generic
|
||||
// CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
|
||||
// CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
|
||||
// CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
|
||||
// CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
|
||||
// CHECK: %[[x13:.*]] = linalg.index 2 : index
|
||||
// CHECK: %[[x14:.*]] = linalg.index 3 : index
|
||||
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64:.*]] : i64 to f32
|
||||
// CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
|
||||
// CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
|
||||
// CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
|
||||
// CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32
|
||||
// CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32
|
||||
// CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32
|
||||
// CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32
|
||||
// CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
|
||||
// CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
|
||||
// CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
|
||||
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64:.*]] : i64 to f32
|
||||
// CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
|
||||
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
|
||||
// CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
|
||||
// CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32
|
||||
// CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32
|
||||
// CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32
|
||||
// CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32
|
||||
// CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32
|
||||
// CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32
|
||||
// CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32
|
||||
// CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32
|
||||
// CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32
|
||||
// CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32
|
||||
// CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32
|
||||
// CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32
|
||||
// CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32
|
||||
// CHECK: %[[x43:.*]] = linalg.index 0 : index
|
||||
// CHECK: %[[x44:.*]] = linalg.index 1 : index
|
||||
// CHECK: %[[x45:.*]] = linalg.index 2 : index
|
||||
// CHECK: %[[x46:.*]] = linalg.index 3 : index
|
||||
// CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64
|
||||
// CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index
|
||||
// CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64
|
||||
// CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index
|
||||
// CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64
|
||||
// CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index
|
||||
// CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64
|
||||
// CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index
|
||||
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32>
|
||||
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32>
|
||||
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : tensor<1x1x2x4xf32>
|
||||
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32>
|
||||
// CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32
|
||||
// CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32
|
||||
// CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32
|
||||
// CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %extracted : f32
|
||||
// CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32
|
||||
// CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32
|
||||
// CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32
|
||||
// CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32
|
||||
// CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32
|
||||
// CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32
|
||||
// CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32
|
||||
// CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32
|
||||
// CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32
|
||||
// CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32
|
||||
// CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32
|
||||
// CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32
|
||||
// CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32
|
||||
// CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32
|
||||
// CHECK: %[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32
|
||||
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x1:.*]], %[[x2:.*]], %[[x3:.*]], %[[x4:.*]]] : tensor<1x1x2x4xf32>
|
||||
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
|
||||
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
|
||||
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x1]], %[[x2]]
|
||||
// CHECK: %[[dx0p00:.*]] = arith.mulf %[[dx0:.*]], %[[extracted]]
|
||||
// CHECK: %[[dx1p01:.*]] = arith.mulf %[[dx1:.*]], %[[extracted_7]]
|
||||
// CHECK: %[[sum:.*]] = arith.addf %[[dx0p00]], %[[dx1p01]]
|
||||
// CHECK: %[[left:.*]] = arith.mulf %[[dy0:.*]], %[[sum]]
|
||||
// CHECK: %[[dx0p10:.*]] = arith.mulf %[[dx0]], %[[extracted_8]]
|
||||
// CHECK: %[[dx1p11:.*]] = arith.mulf %[[dx1]], %[[extracted_9]]
|
||||
// CHECK: %[[sum2:.*]] = arith.addf %[[dx0p10]], %[[dx1p11]]
|
||||
// CHECK: %[[right:.*]] = arith.mulf %[[dy1:.*]], %[[sum2]]
|
||||
// CHECK: %[[retval:.*]] = arith.addf %[[left]], %[[right]]
|
||||
%none = torch.constant.none
|
||||
%none_0 = torch.constant.none
|
||||
%int0 = torch.constant.int 0
|
||||
|
|
Loading…
Reference in New Issue