OnnxToTorch lowering resize op (#3013)

https://github.com/nod-ai/SHARK-Turbine/issues/358
Adds a lowering from ONNX to linalg for bilinear and nearest Resize, with
support for using either scales or sizes to compute the output shape. Uses
the half_pixel coordinate transformation for bilinear mode and asymmetric
for nearest mode. See
https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Adds two
lowering paths -- one for bilinear and the other for nearest.
aldesilv 2024-05-08 14:35:03 -07:00 committed by GitHub
parent bce800a3f4
commit ec6d7aa5d2
8 changed files with 776 additions and 0 deletions
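For reference, a minimal Python sketch of the two coordinate transformations used by this lowering (the helper names are mine, not part of the patch; the formulas are from the ONNX Resize spec linked above):

def half_pixel(x_resized: float, scale: float) -> float:
    # bilinear path: shift by half a pixel before and after scaling
    return (x_resized + 0.5) / scale - 0.5

def asymmetric(x_resized: float, scale: float) -> float:
    # nearest path: plain division; nearest_mode=floor then truncates
    return x_resized / scale

# e.g. upscaling a length-4 axis to length 8 (scale = 2.0):
# half_pixel(0, 2.0) == -0.25, asymmetric(0, 2.0) == 0.0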


@@ -7260,6 +7260,35 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
}];
}
def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchOptionalListOfTorchIntType:$size,
AnyTorchOptionalListOfTorchFloatType:$scale_factor,
Torch_StringType:$mode,
AnyTorchOptionalBoolType:$align_corners,
AnyTorchOptionalBoolType:$recompute_scale_factor,
Torch_BoolType:$antialias
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
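The generated op mirrors torch.nn.functional.interpolate, so a rough PyTorch equivalent of the two call shapes this lowering produces is (a sketch of mine, assuming a float NCHW tensor):

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 2, 4)
# sizes path and scales path, matching the two branches handled below
y1 = F.interpolate(x, size=[4, 8], mode="bilinear")
y2 = F.interpolate(x, scale_factor=[2.0, 2.0], mode="nearest")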
def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
AllowsTypeRefinement,
HasValueSemantics,


@@ -2637,4 +2637,156 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
rewriter.replaceOp(binder.op, {loss, logProb});
return success();
});
patterns.onOp(
"Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
llvm::SmallVector<Value> operands;
std::string mode, nearest_mode, coordTfMode;
Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
if (auto attr = binder.op->getAttr("torch.onnx.antialias")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for antialias attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.axes")) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for axes attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.exclude_outside")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"exclude_outside attribute");
}
if (auto attr = binder.op->getAttr("torch.onnx.extrapolation_value")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"extrapolation_value attribute");
}
if (auto attr =
binder.op->getAttr("torch.onnx.keep_aspect_ratio_policy")) {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for "
"keep_aspect_ratio_policy attribute");
}
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(mode, "mode", "nearest") ||
binder.customOpNameStringAttr(
coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
binder.customOpNameStringAttr(nearest_mode, "nearest_mode", ""))
return failure();
if (mode == "nearest" && nearest_mode != "floor") {
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: support not present for nearest_mode "
"except floor");
}
Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
Value cstTrue =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
Value modeStrValue;
auto extract = [&rewriter, &binder](Value x, Value v) {
auto xTy = x.getType().cast<Torch::ValueTensorType>();
Type extractTy = rewriter.getType<Torch::FloatType>();
if (isa<IntegerType>(xTy.getDtype()))
extractTy = rewriter.getType<Torch::IntType>();
return rewriter.create<Torch::AtenItemOp>(binder.getLoc(), extractTy,
v);
};
auto getValueList = [&](Value operand) {
SmallVector<Value> itemList;
auto sizes =
dyn_cast<Torch::ValueTensorType>(operand.getType()).getSizes();
Torch::BaseTensorType operandType =
operand.getType().cast<Torch::BaseTensorType>();
SmallVector<int64_t> selectSizes;
selectSizes.push_back(1);
Type selectResultType = operandType.getWithSizesAndDtype(
llvm::ArrayRef(selectSizes), operandType.getOptionalDtype());
MLIRContext *context = binder.op->getContext();
for (int i = sizes[0] - 2; i < sizes[0]; i++) {
Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
Value ext = rewriter.create<Torch::AtenSelectIntOp>(
binder.getLoc(), selectResultType, operand, zero, selectIndex);
Value item = extract(operand, ext);
itemList.push_back(item);
}
auto xTy = operand.getType().cast<Torch::ValueTensorType>();
Value ValueList;
if (isa<IntegerType>(xTy.getDtype())) {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::IntType::get(context)), itemList);
} else {
ValueList = rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::FloatType::get(context)), itemList);
}
return ValueList;
};
Value scalesValueList = noneVal;
Value sizesValueList = noneVal;
Value alignCorners =
coordTfMode == "align_corners" ? cstTrue : cstFalse;
if (mode == "cubic") {
return rewriter.notifyMatchFailure(binder.op,
"unimplemented: bicubic mode");
}
if (mode == "linear") {
modeStrValue = rewriter.create<Torch::ConstantStrOp>(binder.getLoc(),
"bilinear");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizeOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizeOperand);
}
}
if (mode == "nearest") {
modeStrValue =
rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), "nearest");
if (operands.size() < 4) {
Value scaleOperand = operands[2];
scalesValueList = getValueList(scaleOperand);
sizesValueList = noneVal;
} else {
Value sizesOperand = operands[3];
scalesValueList = noneVal;
sizesValueList = getValueList(sizesOperand);
}
}
if (scalesValueList.getType().isa<Torch::NoneType>() &&
sizesValueList.getType().isa<Torch::NoneType>()) {
return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");
}
rewriter
.replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
binder.op, resultType, operands[0], sizesValueList,
scalesValueList, modeStrValue,
/* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
/* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
/*Torch_BoolType:$antialias*/ cstFalse);
return success();
});
}
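Note that getValueList keeps only the trailing two entries of the ONNX scales/sizes operand: for NCHW inputs the N and C entries do not change, and aten.__interpolate takes only the spatial dims. A rough Python equivalent of that selection (hypothetical helper, for illustration):

def spatial_entries(values):
    # mirror of getValueList: iterate i from len(values)-2 to len(values)-1
    return [values[i] for i in range(len(values) - 2, len(values))]

assert spatial_entries([1.0, 1.0, 2.0, 3.0]) == [2.0, 3.0]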


@@ -2711,6 +2711,341 @@ public:
};
} // namespace
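// Computes one output element of a nearest-neighbor resize. The output index
// is mapped back to the input with the "asymmetric" coordinate transformation
// (x_original = x_resized / scale) and rounded down (nearest_mode = "floor").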
static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH,
Value outputSizeW, Value input,
Value inputSizeH, Value inputSizeW) {
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
Value yOut = b.create<linalg::IndexOp>(loc, 2);
Value xOut = b.create<linalg::IndexOp>(loc, 3);
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
Value outputSizeHFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
Value outputSizeWFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
// scale = length_resized / length_original
// x_original = x_resized / scale
Value hScale = b.create<arith::DivFOp>(loc, outputSizeHFP, inputHFP);
Value wScale = b.create<arith::DivFOp>(loc, outputSizeWFP, inputWFP);
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
Value yProj = b.create<arith::DivFOp>(loc, yOutFP, hScale);
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
Value xProj = b.create<arith::DivFOp>(loc, xOutFP, wScale);
// get nearest pixel using floor
Value yNearestFP = b.create<math::FloorOp>(loc, yProj);
Value xNearestFP = b.create<math::FloorOp>(loc, xProj);
Value yNearestInt =
b.create<arith::FPToSIOp>(loc, b.getI64Type(), yNearestFP);
Value yNearest =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), yNearestInt);
Value xNearestInt =
b.create<arith::FPToSIOp>(loc, b.getI64Type(), xNearestFP);
Value xNearest =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), xNearestInt);
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
int hDimOffset = 2;
indices[hDimOffset] = yNearest;
indices[hDimOffset + 1] = xNearest;
Value retVal = b.create<tensor::ExtractOp>(loc, input, indices);
return retVal;
}
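// Computes one output element of a bilinear resize. The output index is
// projected into the input with either the "align_corners" or the
// "half_pixel" coordinate transformation, clamped to the input extent, and
// the four neighboring pixels are blended with weights proportional to the
// distance from the projected point.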
static Value BilinearInterpolate(OpBuilder &b,
Aten__InterpolateSizeListScaleListOp op,
Location loc, Value outputSizeH,
Value outputSizeW, Value input,
Value inputSizeH, Value inputSizeW) {
int hDimOffset = 2;
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
Value cstOneEps = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.001));
Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));
Value yOut = b.create<linalg::IndexOp>(loc, 2);
Value xOut = b.create<linalg::IndexOp>(loc, 3);
// default to false when align_corners is none rather than reading an
// uninitialized flag
bool alignCornersBool = false;
matchPattern(op.getAlignCorners(), m_TorchConstantBool(&alignCornersBool));
Value yProj, xProj;
if (alignCornersBool) {
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
Value outputSizeHFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputHFP, cstOneFloat);
Value outputSizeHSubOne =
b.create<arith::SubFOp>(loc, outputSizeHFP, cstOneFloat);
Value hScale =
b.create<arith::DivFOp>(loc, inputHSubOne, outputSizeHSubOne);
Value yProjBeforeClamp = b.create<arith::MulFOp>(loc, yOutFP, hScale);
Value yMax = b.create<arith::MaximumFOp>(loc, yProjBeforeClamp, zero);
// clamp to the input extent so that yHigh = floor(yProj) + 1 stays in bounds
Value inputHSubOneEps =
b.create<arith::SubFOp>(loc, inputHFP, cstOneEps);
yProj = b.create<arith::MinimumFOp>(loc, yMax, inputHSubOneEps);
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
Value outputSizeWFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputWFP, cstOneFloat);
Value outputSizeWSubOne =
b.create<arith::SubFOp>(loc, outputSizeWFP, cstOneFloat);
Value wScale =
b.create<arith::DivFOp>(loc, inputWSubOne, outputSizeWSubOne);
Value xProjBeforeClamp = b.create<arith::MulFOp>(loc, xOutFP, wScale);
Value xMax = b.create<arith::MaximumFOp>(loc, xProjBeforeClamp, zero);
// clamp to the input extent so that xHigh = floor(xProj) + 1 stays in bounds
Value inputWSubOneEps =
b.create<arith::SubFOp>(loc, inputWFP, cstOneEps);
xProj = b.create<arith::MinimumFOp>(loc, xMax, inputWSubOneEps);
} else {
// y_original = (y_resized + 0.5) / scale - 0.5
Value inputHFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeH);
Value outputSizeHFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeH);
Value hScale = b.create<arith::DivFOp>(loc, outputSizeHFP, inputHFP);
Value yOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), yOut);
Value yOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), yOutInt);
Value yPlusHalf = b.create<arith::AddFOp>(loc, yOutFP, cstHalf);
Value yDivScale = b.create<arith::DivFOp>(loc, yPlusHalf, hScale);
Value ySubHalf = b.create<arith::SubFOp>(loc, yDivScale, cstHalf);
Value yMax = b.create<arith::MaximumFOp>(loc, ySubHalf, zero);
Value inputHSubOne = b.create<arith::SubFOp>(loc, inputHFP, cstOneEps);
yProj = b.create<arith::MinimumFOp>(loc, yMax, inputHSubOne);
Value inputWFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSizeW);
Value outputSizeWFP =
b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSizeW);
Value wScale = b.create<arith::DivFOp>(loc, outputSizeWFP, inputWFP);
Value xOutInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), xOut);
Value xOutFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), xOutInt);
Value xPlusHalf = b.create<arith::AddFOp>(loc, xOutFP, cstHalf);
Value xDivScale = b.create<arith::DivFOp>(loc, xPlusHalf, wScale);
Value xSubHalf = b.create<arith::SubFOp>(loc, xDivScale, cstHalf);
// clamp
Value xMax = b.create<arith::MaximumFOp>(loc, xSubHalf, zero);
Value inputWSubOne = b.create<arith::SubFOp>(loc, inputWFP, cstOneEps);
xProj = b.create<arith::MinimumFOp>(loc, xMax, inputWSubOne);
}
Value yLow = b.create<math::FloorOp>(loc, yProj);
Value yProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, yProj);
Value yHigh = b.create<math::FloorOp>(loc, yProjPlusOne);
Value xLow = b.create<math::FloorOp>(loc, xProj);
Value xProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, xProj);
Value xHigh = b.create<math::FloorOp>(loc, xProjPlusOne);
SmallVector<Value> indices;
for (unsigned i = 0; i < inputRank; i++) {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
Value yLowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yLow);
Value yLowIdx = b.create<arith::IndexCastOp>(loc, b.getIndexType(), yLowInt);
Value xLowInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xLow);
Value xLowIdx = b.create<arith::IndexCastOp>(loc, b.getIndexType(), xLowInt);
Value yHighInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), yHigh);
Value yHighIdx =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), yHighInt);
Value xHighInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), xHigh);
Value xHighIdx =
b.create<arith::IndexCastOp>(loc, b.getIndexType(), xHighInt);
indices[hDimOffset] = yLowIdx;
indices[hDimOffset + 1] = xLowIdx;
Value p00 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[hDimOffset] = yLowIdx;
indices[hDimOffset + 1] = xHighIdx;
Value p01 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[hDimOffset] = yHighIdx;
indices[hDimOffset + 1] = xLowIdx;
Value p10 = b.create<tensor::ExtractOp>(loc, input, indices);
indices[hDimOffset] = yHighIdx;
indices[hDimOffset + 1] = xHighIdx;
Value p11 = b.create<tensor::ExtractOp>(loc, input, indices);
// p00 p01
// p10 p11
// (xhigh - xproj) / (xhigh - xlow) * p00 + (xproj - xlow) /
// (xhigh - xlow) * p01
Value xHighMinusxProj = b.create<arith::SubFOp>(loc, xHigh, xProj);
Value xHighMinusxLow = b.create<arith::SubFOp>(loc, xHigh, xLow);
Value w0 = b.create<arith::DivFOp>(loc, xHighMinusxProj, xHighMinusxLow);
Value lhs = b.create<arith::MulFOp>(loc, w0, p00);
Value xProjMinusxLow = b.create<arith::SubFOp>(loc, xProj, xLow);
Value w1 = b.create<arith::DivFOp>(loc, xProjMinusxLow, xHighMinusxLow);
Value rhs = b.create<arith::MulFOp>(loc, w1, p01);
Value xInter = b.create<arith::AddFOp>(loc, lhs, rhs);
// (xhigh - xproj) / (xhigh - xlow) * p10 + (xproj - xlow) /
// (xhigh - xlow) * p11
lhs = b.create<arith::MulFOp>(loc, w0, p10);
rhs = b.create<arith::MulFOp>(loc, w1, p11);
Value xInter1 = b.create<arith::AddFOp>(loc, lhs, rhs);
// (yhigh - yproj) / (yhigh - ylow) * xInter + (yproj - ylow)
// / (yhigh - ylow) * xInter1
Value yHighMinusyProj = b.create<arith::SubFOp>(loc, yHigh, yProj);
Value yHighMinusyLow = b.create<arith::SubFOp>(loc, yHigh, yLow);
w0 = b.create<arith::DivFOp>(loc, yHighMinusyProj, yHighMinusyLow);
lhs = b.create<arith::MulFOp>(loc, w0, xInter);
Value yProjMinusyLow = b.create<arith::SubFOp>(loc, yProj, yLow);
w1 = b.create<arith::DivFOp>(loc, yProjMinusyLow, yHighMinusyLow);
rhs = b.create<arith::MulFOp>(loc, w1, xInter1);
Value retVal = b.create<arith::AddFOp>(loc, lhs, rhs);
return retVal;
}
namespace {
class ConvertInterpolateOp
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
std::string mode;
matchPattern(op.getMode(), m_TorchConstantStr(mode));
if (mode != "bilinear" && mode != "nearest") {
return failure();
}
Location loc = op->getLoc();
Value input = adaptor.getInput();
auto inputType = input.getType().cast<RankedTensorType>();
auto inputRank = inputType.getRank();
if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op");
}
SmallVector<Value, 2> outputSizeIntValues;
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
SmallVector<Value, 2> ScaleFactorTorchFloat;
if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat))
return rewriter.notifyMatchFailure(
op, "unimplemented: the scale_factor is not constructed from "
"ListConstruct");
SmallVector<Value, 2> ScaleFactorFloatValues;
ScaleFactorFloatValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
Value inputSizeH = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(inputType.getShape()[2]));
Value inputHFP = rewriter.create<arith::SIToFPOp>(
loc, rewriter.getF32Type(), inputSizeH);
Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
ScaleFactorFloatValues[0]);
Value outputSizeH = rewriter.create<arith::MulFOp>(loc, inputHFP, scale);
Value outputH = rewriter.create<math::FloorOp>(loc, outputSizeH);
outputH =
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
Value inputSizeW = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI64IntegerAttr(inputType.getShape()[3]));
Value inputWFP = rewriter.create<arith::SIToFPOp>(
loc, rewriter.getF32Type(), inputSizeW);
scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
ScaleFactorFloatValues[1]);
Value outputSizeW = rewriter.create<arith::MulFOp>(loc, inputWFP, scale);
Value outputW = rewriter.create<math::FloorOp>(loc, outputSizeW);
outputW =
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputW);
outputSizeIntValues.push_back(outputH);
outputSizeIntValues.push_back(outputW);
} else {
SmallVector<Value, 2> outputSizeTorchInt;
if (!getListConstructElements(op.getSize(), outputSizeTorchInt))
return rewriter.notifyMatchFailure(
op, "unimplemented: the output_size is not constructed from "
"ListConstruct");
outputSizeIntValues = getTypeConvertedValues(
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
}
int hDimOffset = 2;
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
dims[hDimOffset + 1] =
castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
Value outTensor = rewriter.create<tensor::EmptyOp>(
loc, getAsOpFoldResult(dims), inputType.getElementType());
AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
SmallVector<utils::IteratorType> iteratorTypes(
inputRank, utils::IteratorType::parallel);
Value finalRes =
rewriter
.create<linalg::GenericOp>(
loc, outTensor.getType(), ValueRange{}, outTensor,
/*indexingMaps=*/idMap,
/*iteratorTypes=*/iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
Value outputSizeH = outputSizeIntValues[0];
Value outputSizeW = outputSizeIntValues[1];
Value inputSizeH = b.create<arith::ConstantOp>(
loc, b.getI64IntegerAttr(inputType.getShape()[2]));
Value inputSizeW = b.create<arith::ConstantOp>(
loc, b.getI64IntegerAttr(inputType.getShape()[3]));
Value retVal;
if (mode == "nearest") {
retVal =
NearestInterpolate(b, loc, outputSizeH, outputSizeW,
input, inputSizeH, inputSizeW);
} else if (mode == "bilinear") {
retVal = BilinearInterpolate(b, op, loc, outputSizeH,
outputSizeW, input, inputSizeH,
inputSizeW);
}
b.create<linalg::YieldOp>(loc, retVal);
})
.getResult(0);
Type newResultType =
getTypeConverter()->convertType(op.getResult().getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
return success();
}
};
} // namespace
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target) {
@@ -2766,4 +3101,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
target.addIllegalOp<AtenGridSamplerOp>();
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
target.addIllegalOp<Aten__InterpolateSizeListScaleListOp>();
patterns.add<ConvertInterpolateOp>(typeConverter, context);
}
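As a sanity check on the scalar math in BilinearInterpolate, a minimal Python model of the half_pixel path (my own sketch, not part of the patch) that should agree with the linalg.generic body element for element:

import math

def bilinear_half_pixel(img, out_h, out_w):
    # img: 2D list of floats (H x W); returns an out_h x out_w list
    in_h, in_w = len(img), len(img[0])
    out = [[0.0] * out_w for _ in range(out_h)]
    for y in range(out_h):
        for x in range(out_w):
            # project, then clamp to [0, in - 1.001] as the lowering does
            yp = min(max((y + 0.5) * in_h / out_h - 0.5, 0.0), in_h - 1.001)
            xp = min(max((x + 0.5) * in_w / out_w - 0.5, 0.0), in_w - 1.001)
            y0, x0 = math.floor(yp), math.floor(xp)
            y1, x1 = math.floor(yp + 1), math.floor(xp + 1)
            wy, wx = yp - y0, xp - x0
            top = (1 - wx) * img[y0][x0] + wx * img[y0][x1]
            bot = (1 - wx) * img[y1][x0] + wx * img[y1][x1]
            out[y][x] = (1 - wy) * top + wy * bot
    return out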


@@ -6654,6 +6654,70 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %4 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>, %arg3: !torch.str, %arg4: !torch.optional<bool>, %arg5: !torch.optional<bool>, %arg6: !torch.bool) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: Either size or scale_factor must be presented\"\n"
" %str_0 = torch.constant.str \"AssertionError: Must specify exactly one of size and scale_factor\"\n"
" %none = torch.constant.none\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %int2 = torch.constant.int 2\n"
" %int3 = torch.constant.int 3\n"
" %false = torch.constant.bool false\n"
" %true = torch.constant.bool true\n"
" %0 = torch.prim.Uninitialized : !torch.list<int>\n"
" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = torch.aten.__isnot__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.list<int>) {\n"
" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
" %8 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" torch.prim.If %8 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %9 = torch.aten.__getitem__.t %7, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %10 = torch.aten.append.t %3, %9 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %11 = torch.aten.__getitem__.t %7, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %12 = torch.aten.append.t %3, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list<int>\n"
" } else {\n"
" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.list<int>) {\n"
" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<float>> -> !torch.list<float>\n"
" %10 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
" torch.prim.If %10 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
" %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
" %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n"
" %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n"
" %15 = torch.aten.append.t %3, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
" %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
" %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n"
" %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n"
" %20 = torch.aten.append.t %3, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list<int>\n"
" } else {\n"
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.list<int>\n"
" }\n"
" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.list<int>\n"
" }\n"
" %6 = torch.prim.If %5#0 -> (!torch.list<int>) {\n"
" torch.prim.If.yield %5#1 : !torch.list<int>\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield %3 : !torch.list<int>\n"
" }\n"
" return %6 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"
" %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n" " %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
@@ -10159,6 +10223,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>, %arg3: !torch.str, %arg4: !torch.optional<bool>, %arg5: !torch.optional<bool>, %arg6: !torch.bool) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n" " func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
" %none = torch.constant.none\n" " %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"


@@ -338,6 +338,26 @@ def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation
    output = [input[0],input[1],grid[1],grid[2]]
    return output
def aten〇__interpolate〇size_list_scale_list〡shape(input: List[int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> List[int]:
    output = [input[0], input[1]]
    if size is not None:
        assert (
            scale_factor is None
        ), "Must specify exactly one of size and scale_factor"
        output.append(size[0])
        output.append(size[1])
        return output
    elif scale_factor is not None:
        assert (
            size is None
        ), "Must specify exactly one of size and scale_factor"
        output.append(int(scale_factor[0] * input[2]))
        output.append(int(scale_factor[1] * input[3]))
        return output
    assert 0, "Either size or scale_factor must be present"
    return output
def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
    # Obtained through trial and error on a few examples in PyTorch:
    assert start < len(a), "start out of bounds"
@@ -2330,6 +2350,10 @@ def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dty
    grid_rank, grid_dtype = input_rank_dtype
    return input_dtype
def aten〇__interpolate〇size_list_scale_list〡dtype(input_rank_dtype: Tuple[int, int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> int:
    input_rank, input_dtype = input_rank_dtype
    return input_dtype
@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
                       ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
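A quick check of the new shape function under both modes (my own invocation, assuming the usual NCHW layout):

# sizes path: output spatial dims come straight from `size`
assert aten〇__interpolate〇size_list_scale_list〡shape(
    [1, 1, 2, 4], size=[4, 8]) == [1, 1, 4, 8]
# scales path: output dims are int(scale * input dim)
assert aten〇__interpolate〇size_list_scale_list〡shape(
    [1, 1, 2, 4], scale_factor=[2.0, 2.0]) == [1, 1, 4, 8]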


@@ -634,6 +634,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
    emit_with_mutating_variants(
        "aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)"
    )
    emit(
        "aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)"
    )
    emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
    emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
    emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")


@@ -2006,3 +2006,24 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
%0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>)
return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>
}
// -----
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
// CHECK-LABEL: func.func @test_resize_sizes_linear
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
%0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
return %0 : !torch.vtensor<[?,?,?,?],f32>
}


@@ -0,0 +1,142 @@
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func.func @test_resize_sizes_linear
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[generic:.*]] = linalg.generic
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
// CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
// CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
// CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
// CHECK: %[[x13:.*]] = linalg.index 2 : index
// CHECK: %[[x14:.*]] = linalg.index 3 : index
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
// CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
// CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
// CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
// CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32
// CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32
// CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32
// CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32
// CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
// CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
// CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
// CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
// CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
// CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
// CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32
// CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32
// CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32
// CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32
// CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32
// CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32
// CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32
// CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32
// CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32
// CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32
// CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32
// CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32
// CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32
// CHECK: %[[x43:.*]] = linalg.index 0 : index
// CHECK: %[[x44:.*]] = linalg.index 1 : index
// CHECK: %[[x45:.*]] = linalg.index 2 : index
// CHECK: %[[x46:.*]] = linalg.index 3 : index
// CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64
// CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index
// CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64
// CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index
// CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64
// CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index
// CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64
// CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : tensor<1x1x2x4xf32>
// CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32>
// CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32
// CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32
// CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32
// CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %extracted : f32
// CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32
// CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32
// CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32
// CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32
// CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32
// CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32
// CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32
// CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32
// CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32
// CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32
// CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32
// CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32
// CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32
// CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32
// CHECK: %[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32
%none = torch.constant.none
%none_0 = torch.constant.none
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%true = torch.constant.bool true
%str = torch.constant.str "bilinear"
%int2 = torch.constant.int 2
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
%int3 = torch.constant.int 3
%2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
%4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %5 : !torch.vtensor<[?,?,?,?],f32>
}
// -----
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[GENERIC:.*]] = linalg.generic
// CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
// CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
// CHECK: %[[x13:.*]] = linalg.index 2 : index
// CHECK: %[[x14:.*]] = linalg.index 3 : index
// CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
// CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
// CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
// CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
// CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
// CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32
// CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
// CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
// CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
// CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64
// CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32
// CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32
// CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32
// CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32
// CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64
// CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
// CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
// CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
// CHECK: %[[x35:.*]] = linalg.index 0 : index
// CHECK: %[[x36:.*]] = linalg.index 1 : index
// CHECK: %[[x37:.*]] = linalg.index 2 : index
// CHECK: %[[x38:.*]] = linalg.index 3 : index
// CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32>
// CHECK: linalg.yield %[[extracted]] : f32
%none = torch.constant.none
%none_0 = torch.constant.none
%int0 = torch.constant.int 0
%false = torch.constant.bool false
%true = torch.constant.bool true
%str = torch.constant.str "nearest"
%int2 = torch.constant.int 2
%0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
%int3 = torch.constant.int 3
%2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
%3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
%4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
%5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
return %5 : !torch.vtensor<[?,?,?,?],f32>
}
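And a matching scalar model for the nearest path exercised by the second test (again my own sketch, mirroring NearestInterpolate: asymmetric transform plus floor):

import math

def nearest_asymmetric_floor(img, out_h, out_w):
    # img: 2D list (H x W); matches the linalg.generic body in the test above
    in_h, in_w = len(img), len(img[0])
    return [[img[math.floor(y * in_h / out_h)][math.floor(x * in_w / out_w)]
             for x in range(out_w)]
            for y in range(out_h)]

# e.g. resizing the 2x4 test input to 4x8 repeats each input pixel 2x2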