mirror of https://github.com/llvm/torch-mlir
OnnxToTorch lowering resize op (#3013)
https://github.com/nod-ai/SHARK-Turbine/issues/358 Adds a lowering from ONNX to linalg for bilinear and nearest resize, with support for using either scales or sizes to compute the resized shape. Uses the half-pixel coordinate transform for bilinear mode and the asymmetric coordinate transform for nearest mode. See https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize. Added two passes -- one for bilinear and the other for nearest. pull/3304/head
parent
bce800a3f4
commit
ec6d7aa5d2
|
@ -7260,6 +7260,35 @@ def Torch_AtenMaskedScatter_Op : Torch_Op<"aten.masked_scatter_", [
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ODS definition of the Torch dialect op for
// `aten::__interpolate.size_list_scale_list`.
// Operands: the input tensor, an optional int list `size`, an optional float
// list `scale_factor`, the `mode` string, two optional bools (`align_corners`,
// `recompute_scale_factor`) and the `antialias` bool; it has a single tensor
// result. Parsing and printing are delegated to the default Torch op helpers
// with 7 operands and 1 result.
def Torch_Aten__InterpolateSizeListScaleListOp : Torch_Op<"aten.__interpolate.size_list_scale_list", [
|
||||||
|
AllowsTypeRefinement,
|
||||||
|
HasValueSemantics,
|
||||||
|
ReadOnly
|
||||||
|
]> {
|
||||||
|
let summary = "Generated op for `aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)`";
|
||||||
|
let arguments = (ins
|
||||||
|
AnyTorchTensorType:$input,
|
||||||
|
AnyTorchOptionalListOfTorchIntType:$size,
|
||||||
|
AnyTorchOptionalListOfTorchFloatType:$scale_factor,
|
||||||
|
Torch_StringType:$mode,
|
||||||
|
AnyTorchOptionalBoolType:$align_corners,
|
||||||
|
AnyTorchOptionalBoolType:$recompute_scale_factor,
|
||||||
|
Torch_BoolType:$antialias
|
||||||
|
);
|
||||||
|
let results = (outs
|
||||||
|
AnyTorchOptionalTensorType:$result
|
||||||
|
);
|
||||||
|
let hasCustomAssemblyFormat = 1;
|
||||||
|
let extraClassDefinition = [{
|
||||||
|
ParseResult Aten__InterpolateSizeListScaleListOp::parse(OpAsmParser &parser, OperationState &result) {
|
||||||
|
return parseDefaultTorchOp(parser, result, 7, 1);
|
||||||
|
}
|
||||||
|
void Aten__InterpolateSizeListScaleListOp::print(OpAsmPrinter &printer) {
|
||||||
|
printDefaultTorchOp(printer, *this, 7, 1);
|
||||||
|
}
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
|
def Torch_AtenAdaptiveAvgPool1dOp : Torch_Op<"aten.adaptive_avg_pool1d", [
|
||||||
AllowsTypeRefinement,
|
AllowsTypeRefinement,
|
||||||
HasValueSemantics,
|
HasValueSemantics,
|
||||||
|
|
|
@ -2637,4 +2637,156 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
|
||||||
rewriter.replaceOp(binder.op, {loss, logProb});
|
rewriter.replaceOp(binder.op, {loss, logProb});
|
||||||
return success();
|
return success();
|
||||||
});
|
});
|
||||||
|
// Lowers onnx.Resize (registered from opset 11) to
// torch.aten.__interpolate.size_list_scale_list.
//
// Supported: mode "linear" (emitted as "bilinear") and mode "nearest" with
// nearest_mode "floor". Mode "cubic" and the antialias / axes /
// exclude_outside / extrapolation_value / keep_aspect_ratio_policy attributes
// are rejected with a match failure.
//
// The target shape is taken from operand 3 when at least four tensor operands
// are bound, and from operand 2 otherwise. Only the last two entries of that
// 1-D tensor are forwarded; an integer dtype yields the `size` list and any
// other dtype yields the `scale_factor` list of the emitted op.
patterns.onOp(
    "Resize", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      Torch::ValueTensorType resultType;
      llvm::SmallVector<Value> operands;
      std::string mode, nearest_mode, coordTfMode;

      // Attributes this lowering does not model: bail out if any is present.
      for (llvm::StringRef attrName :
           {"antialias", "axes", "exclude_outside", "extrapolation_value",
            "keep_aspect_ratio_policy"}) {
        if (binder.op->getAttr((llvm::Twine("torch.onnx.") + attrName).str()))
          return rewriter.notifyMatchFailure(
              binder.op,
              llvm::Twine("unimplemented: support not present for ") +
                  attrName + " attribute");
      }

      if (binder.tensorOperandsList(operands) ||
          binder.tensorResultType(resultType) ||
          binder.customOpNameStringAttr(mode, "mode", "nearest") ||
          binder.customOpNameStringAttr(
              coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
          binder.customOpNameStringAttr(nearest_mode, "nearest_mode", ""))
        return failure();

      if (mode == "nearest" && nearest_mode != "floor") {
        return rewriter.notifyMatchFailure(
            binder.op, "unimplemented: support not present for nearest_mode "
                       "except floor");
      }
      if (mode == "cubic") {
        return rewriter.notifyMatchFailure(binder.op,
                                           "unimplemented: bicubic mode");
      }

      // Translate the ONNX mode into the spelling used by the torch op.
      llvm::StringRef torchMode;
      if (mode == "linear")
        torchMode = "bilinear";
      else if (mode == "nearest")
        torchMode = "nearest";
      else
        return rewriter.notifyMatchFailure(binder.op, "unknown scaling mode");

      // operands[2] / operands[3] are indexed below, so make sure they exist.
      if (operands.size() < 3) {
        return rewriter.notifyMatchFailure(
            binder.op, "expected at least three tensor operands");
      }
      const bool hasSizesOperand = operands.size() >= 4;
      Value shapeOperand = operands[hasSizesOperand ? 3 : 2];

      // The last two entries are selected by constant index, which requires a
      // statically known 1-D shape with at least two elements.
      auto shapeOperandTy =
          dyn_cast<Torch::ValueTensorType>(shapeOperand.getType());
      if (!shapeOperandTy || !shapeOperandTy.hasSizes() ||
          !shapeOperandTy.hasDtype()) {
        return rewriter.notifyMatchFailure(
            binder.op, "unimplemented: scales/sizes operand must be a value "
                       "tensor with known shape and dtype");
      }
      llvm::ArrayRef<int64_t> shapeOperandSizes = shapeOperandTy.getSizes();
      if (shapeOperandSizes.size() != 1 ||
          shapeOperandSizes[0] == Torch::kUnknownSize ||
          shapeOperandSizes[0] < 2) {
        return rewriter.notifyMatchFailure(
            binder.op, "unimplemented: scales/sizes operand must be a static "
                       "1-D tensor with at least two elements");
      }
      const int64_t numEntries = shapeOperandSizes[0];

      Value noneVal = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
      Value zero = rewriter.create<Torch::ConstantIntOp>(
          binder.getLoc(), rewriter.getType<Torch::IntType>(),
          rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
      Value cstFalse =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
      Value cstTrue =
          rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), true);
      Value modeStrValue =
          rewriter.create<Torch::ConstantStrOp>(binder.getLoc(), torchMode);
      Value alignCorners =
          coordTfMode == "align_corners" ? cstTrue : cstFalse;

      // Integer tensors produce !torch.int items, everything else
      // !torch.float items.
      Type itemTy = rewriter.getType<Torch::FloatType>();
      if (isa<IntegerType>(shapeOperandTy.getDtype()))
        itemTy = rewriter.getType<Torch::IntType>();

      SmallVector<int64_t> selectSizes;
      selectSizes.push_back(1);
      Type selectResultType = shapeOperandTy.getWithSizesAndDtype(
          llvm::ArrayRef(selectSizes), shapeOperandTy.getOptionalDtype());

      // Extract the trailing two entries as scalars.
      SmallVector<Value> itemList;
      for (int64_t i = numEntries - 2; i < numEntries; ++i) {
        Value selectIndex = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getType<Torch::IntType>(),
            rewriter.getIntegerAttr(rewriter.getIntegerType(64), i));
        Value selected = rewriter.create<Torch::AtenSelectIntOp>(
            binder.getLoc(), selectResultType, shapeOperand, zero,
            selectIndex);
        Value item = rewriter.create<Torch::AtenItemOp>(binder.getLoc(),
                                                        itemTy, selected);
        itemList.push_back(item);
      }
      Value valueList = rewriter.create<Torch::PrimListConstructOp>(
          binder.getLoc(), Torch::ListType::get(itemTy), itemList);

      Value sizesValueList = hasSizesOperand ? valueList : noneVal;
      Value scalesValueList = hasSizesOperand ? noneVal : valueList;

      rewriter
          .replaceOpWithNewOp<Torch::Aten__InterpolateSizeListScaleListOp>(
              binder.op, resultType, operands[0], sizesValueList,
              scalesValueList, modeStrValue,
              /* AnyTorchOptionalBoolType:$align_corners */ alignCorners,
              /* AnyTorchOptionalBoolType:$recompute_scale_factor */ noneVal,
              /*Torch_BoolType:$antialias*/ cstFalse);
      return success();
    });
|
||||||
}
|
}
|
||||||
|
|
|
@ -2711,6 +2711,341 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
// Emits the nearest-neighbour sample of `input` for the current iteration
// point, reading linalg.index 2 as the output row and linalg.index 3 as the
// output column.
//
// Per axis: scale = output_size / input_size, source = floor(out_idx / scale).
// All arithmetic is done in f32; the floored coordinate is converted back to
// an index and used to extract from `input`, with every other dimension
// indexed by its own linalg.index.
static Value NearestInterpolate(OpBuilder &b, Location loc, Value outputSizeH,
                                Value outputSizeW, Value input,
                                Value inputSizeH, Value inputSizeW) {
  const int64_t rank = input.getType().cast<RankedTensorType>().getRank();

  // Small cast helpers; none of them reorders the emitted ops.
  auto intToF32 = [&](Value v) -> Value {
    return b.create<arith::SIToFPOp>(loc, b.getF32Type(), v);
  };
  auto indexToI64 = [&](Value v) -> Value {
    return b.create<arith::IndexCastOp>(loc, b.getI64Type(), v);
  };
  auto f32ToI64 = [&](Value v) -> Value {
    return b.create<arith::FPToSIOp>(loc, b.getI64Type(), v);
  };
  auto i64ToIndex = [&](Value v) -> Value {
    return b.create<arith::IndexCastOp>(loc, b.getIndexType(), v);
  };

  Value rowOut = b.create<linalg::IndexOp>(loc, 2);
  Value colOut = b.create<linalg::IndexOp>(loc, 3);

  Value inRowsFP = intToF32(inputSizeH);
  Value inColsFP = intToF32(inputSizeW);
  Value outRowsFP = intToF32(outputSizeH);
  Value outColsFP = intToF32(outputSizeW);

  // scale = length_resized / length_original
  // x_original = x_resized / scale
  Value rowScale = b.create<arith::DivFOp>(loc, outRowsFP, inRowsFP);
  Value colScale = b.create<arith::DivFOp>(loc, outColsFP, inColsFP);

  Value rowOutFP = intToF32(indexToI64(rowOut));
  Value rowSrc = b.create<arith::DivFOp>(loc, rowOutFP, rowScale);

  Value colOutFP = intToF32(indexToI64(colOut));
  Value colSrc = b.create<arith::DivFOp>(loc, colOutFP, colScale);

  // "Nearest" here means rounding the source coordinate down.
  Value rowFloor = b.create<math::FloorOp>(loc, rowSrc);
  Value colFloor = b.create<math::FloorOp>(loc, colSrc);

  Value rowNearest = i64ToIndex(f32ToI64(rowFloor));
  Value colNearest = i64ToIndex(f32ToI64(colFloor));

  // Start from the identity indexing and override the two spatial dims.
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < rank; ++dim)
    indices.push_back(b.create<linalg::IndexOp>(loc, dim));

  const int hDimOffset = 2;
  indices[hDimOffset] = rowNearest;
  indices[hDimOffset + 1] = colNearest;

  return b.create<tensor::ExtractOp>(loc, input, indices);
}
|
||||||
|
|
||||||
|
// Emits the bilinear interpolation of `input` for the current iteration
// point, reading linalg.index 2 as the output row and linalg.index 3 as the
// output column. Dimensions 2 and 3 of `input` are the ones interpolated, so
// `input` must have rank >= 4.
//
// Each output index is projected to a fractional source coordinate:
//   align_corners: x_original = x_resized * (in - 1) / (out - 1)
//   otherwise:     x_original = (x_resized + 0.5) / scale - 0.5,
//                  scale = out / in
// The projection is clamped to [0, in - 1.001] so that floor(proj) and
// floor(proj + 1) are both valid indices into `input`. The four neighbouring
// samples are then blended, first along x and then along y.
//
// `op` is only consulted for its align_corners operand; if that operand is not
// a constant bool the half-pixel projection is used.
//
// NOTE(review): for an input extent of 1 the upper clamp bound (in - 1.001) is
// negative, so the low index would still be out of range -- confirm whether
// callers guarantee extents >= 2.
static Value BilinearInterpolate(OpBuilder &b,
                                 Aten__InterpolateSizeListScaleListOp op,
                                 Location loc, Value outputSizeH,
                                 Value outputSizeW, Value input,
                                 Value inputSizeH, Value inputSizeW) {
  const int hDimOffset = 2;
  auto inputType = input.getType().cast<RankedTensorType>();
  const int64_t inputRank = inputType.getRank();

  Value cstOneEps = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.001));
  Value cstOneFloat = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(1.0));
  Value cstHalf = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.5));
  Value zero = b.create<arith::ConstantOp>(loc, b.getF32FloatAttr(0.0));

  Value yOut = b.create<linalg::IndexOp>(loc, 2);
  Value xOut = b.create<linalg::IndexOp>(loc, 3);

  // Default to false so a non-constant / None align_corners never leaves the
  // flag uninitialised.
  bool alignCornersBool = false;
  if (!matchPattern(op.getAlignCorners(),
                    m_TorchConstantBool(&alignCornersBool)))
    alignCornersBool = false;

  // Projects an output index along one axis to a clamped f32 coordinate in the
  // input's index space.
  auto projectToInput = [&](Value outIdx, Value inputSize,
                            Value outputSize) -> Value {
    Value inputFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), inputSize);
    Value outputFP =
        b.create<arith::SIToFPOp>(loc, b.getF32Type(), outputSize);
    Value outInt = b.create<arith::IndexCastOp>(loc, b.getI64Type(), outIdx);
    Value outFP = b.create<arith::SIToFPOp>(loc, b.getF32Type(), outInt);

    Value unclamped;
    if (alignCornersBool) {
      // x_original = x_resized * (length_original - 1) / (length_resized - 1)
      Value inputSubOne = b.create<arith::SubFOp>(loc, inputFP, cstOneFloat);
      Value outputSubOne =
          b.create<arith::SubFOp>(loc, outputFP, cstOneFloat);
      Value scale = b.create<arith::DivFOp>(loc, inputSubOne, outputSubOne);
      // An output extent of 1 would divide by zero; use a zero scale instead
      // so the single output element samples coordinate 0.
      Value hasSpan = b.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGT,
                                              outputSubOne, zero);
      scale = b.create<arith::SelectOp>(loc, hasSpan, scale, zero);
      unclamped = b.create<arith::MulFOp>(loc, outFP, scale);
    } else {
      // x_original = (x_resized + 0.5) / scale - 0.5
      Value scale = b.create<arith::DivFOp>(loc, outputFP, inputFP);
      Value plusHalf = b.create<arith::AddFOp>(loc, outFP, cstHalf);
      Value divScale = b.create<arith::DivFOp>(loc, plusHalf, scale);
      unclamped = b.create<arith::SubFOp>(loc, divScale, cstHalf);
    }

    // Clamp against the *input* extent: the coordinate indexes `input`, and
    // floor(proj + 1) must stay <= input_size - 1.
    Value lowerClamped = b.create<arith::MaximumFOp>(loc, unclamped, zero);
    Value upperBound = b.create<arith::SubFOp>(loc, inputFP, cstOneEps);
    return b.create<arith::MinimumFOp>(loc, lowerClamped, upperBound);
  };

  Value yProj = projectToInput(yOut, inputSizeH, outputSizeH);
  Value xProj = projectToInput(xOut, inputSizeW, outputSizeW);

  // Neighbouring integer coordinates: low = floor(p), high = floor(p + 1).
  Value yLow = b.create<math::FloorOp>(loc, yProj);
  Value yProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, yProj);
  Value yHigh = b.create<math::FloorOp>(loc, yProjPlusOne);

  Value xLow = b.create<math::FloorOp>(loc, xProj);
  Value xProjPlusOne = b.create<arith::AddFOp>(loc, cstOneFloat, xProj);
  Value xHigh = b.create<math::FloorOp>(loc, xProjPlusOne);

  // Identity indexing for every dimension; the spatial ones are overridden
  // per sample below.
  SmallVector<Value> indices;
  for (int64_t dim = 0; dim < inputRank; ++dim)
    indices.push_back(b.create<linalg::IndexOp>(loc, dim));

  auto toIndex = [&](Value fp) -> Value {
    Value asInt = b.create<arith::FPToSIOp>(loc, b.getI64Type(), fp);
    return b.create<arith::IndexCastOp>(loc, b.getIndexType(), asInt);
  };
  Value yLowIdx = toIndex(yLow);
  Value xLowIdx = toIndex(xLow);
  Value yHighIdx = toIndex(yHigh);
  Value xHighIdx = toIndex(xHigh);

  auto sample = [&](Value yIdx, Value xIdx) -> Value {
    indices[hDimOffset] = yIdx;
    indices[hDimOffset + 1] = xIdx;
    return b.create<tensor::ExtractOp>(loc, input, indices);
  };
  // p00 p01
  // p10 p11
  Value p00 = sample(yLowIdx, xLowIdx);
  Value p01 = sample(yLowIdx, xHighIdx);
  Value p10 = sample(yHighIdx, xLowIdx);
  Value p11 = sample(yHighIdx, xHighIdx);

  // wLow * atLow + wHigh * atHigh
  auto blend = [&](Value wLow, Value wHigh, Value atLow,
                   Value atHigh) -> Value {
    Value lhs = b.create<arith::MulFOp>(loc, wLow, atLow);
    Value rhs = b.create<arith::MulFOp>(loc, wHigh, atHigh);
    return b.create<arith::AddFOp>(loc, lhs, rhs);
  };

  // Horizontal weights:
  //   (xhigh - xproj) / (xhigh - xlow) and (xproj - xlow) / (xhigh - xlow)
  Value xHighMinusxProj = b.create<arith::SubFOp>(loc, xHigh, xProj);
  Value xHighMinusxLow = b.create<arith::SubFOp>(loc, xHigh, xLow);
  Value xProjMinusxLow = b.create<arith::SubFOp>(loc, xProj, xLow);
  Value wx0 = b.create<arith::DivFOp>(loc, xHighMinusxProj, xHighMinusxLow);
  Value wx1 = b.create<arith::DivFOp>(loc, xProjMinusxLow, xHighMinusxLow);

  Value xInter = blend(wx0, wx1, p00, p01);  // along the top row
  Value xInter1 = blend(wx0, wx1, p10, p11); // along the bottom row

  // Vertical weights:
  //   (yhigh - yproj) / (yhigh - ylow) and (yproj - ylow) / (yhigh - ylow)
  Value yHighMinusyProj = b.create<arith::SubFOp>(loc, yHigh, yProj);
  Value yHighMinusyLow = b.create<arith::SubFOp>(loc, yHigh, yLow);
  Value yProjMinusyLow = b.create<arith::SubFOp>(loc, yProj, yLow);
  Value wy0 = b.create<arith::DivFOp>(loc, yHighMinusyProj, yHighMinusyLow);
  Value wy1 = b.create<arith::DivFOp>(loc, yProjMinusyLow, yHighMinusyLow);

  return blend(wy0, wy1, xInter, xInter1);
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class ConvertInterpolateOp
|
||||||
|
: public OpConversionPattern<Aten__InterpolateSizeListScaleListOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(Aten__InterpolateSizeListScaleListOp op, OpAdaptor adaptor,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
|
||||||
|
std::string mode;
|
||||||
|
matchPattern(op.getMode(), m_TorchConstantStr(mode));
|
||||||
|
if (mode != "bilinear" && mode != "nearest") {
|
||||||
|
return failure();
|
||||||
|
}
|
||||||
|
|
||||||
|
Location loc = op->getLoc();
|
||||||
|
Value input = adaptor.getInput();
|
||||||
|
auto inputType = input.getType().cast<RankedTensorType>();
|
||||||
|
auto inputRank = inputType.getRank();
|
||||||
|
|
||||||
|
if (inputType.isDynamicDim(2) || inputType.isDynamicDim(3)) {
|
||||||
|
return rewriter.notifyMatchFailure(op, "error: Dynamic dim on resize op");
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<Value, 2> outputSizeIntValues;
|
||||||
|
|
||||||
|
if (!op.getScaleFactor().getType().isa<Torch::NoneType>()) {
|
||||||
|
SmallVector<Value, 2> ScaleFactorTorchFloat;
|
||||||
|
if (!getListConstructElements(op.getScaleFactor(), ScaleFactorTorchFloat))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: the output_size is not constructed from "
|
||||||
|
"ListConstruct");
|
||||||
|
SmallVector<Value, 2> ScaleFactorFloatValues;
|
||||||
|
ScaleFactorFloatValues = getTypeConvertedValues(
|
||||||
|
rewriter, loc, getTypeConverter(), ScaleFactorTorchFloat);
|
||||||
|
Value inputSizeH = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(inputType.getShape()[2]));
|
||||||
|
Value inputHFP = rewriter.create<arith::SIToFPOp>(
|
||||||
|
loc, rewriter.getF32Type(), inputSizeH);
|
||||||
|
Value scale = rewriter.create<arith::TruncFOp>(loc, inputHFP.getType(),
|
||||||
|
ScaleFactorFloatValues[0]);
|
||||||
|
Value outputSizeH = rewriter.create<arith::MulFOp>(loc, inputHFP, scale);
|
||||||
|
Value outputH = rewriter.create<math::FloorOp>(loc, outputSizeH);
|
||||||
|
outputH =
|
||||||
|
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputH);
|
||||||
|
|
||||||
|
Value inputSizeW = rewriter.create<arith::ConstantOp>(
|
||||||
|
loc, rewriter.getI64IntegerAttr(inputType.getShape()[3]));
|
||||||
|
Value inputWFP = rewriter.create<arith::SIToFPOp>(
|
||||||
|
loc, rewriter.getF32Type(), inputSizeW);
|
||||||
|
scale = rewriter.create<arith::TruncFOp>(loc, inputWFP.getType(),
|
||||||
|
ScaleFactorFloatValues[1]);
|
||||||
|
Value outputSizeW = rewriter.create<arith::MulFOp>(loc, inputWFP, scale);
|
||||||
|
Value outputW = rewriter.create<math::FloorOp>(loc, outputSizeW);
|
||||||
|
outputW =
|
||||||
|
rewriter.create<arith::FPToSIOp>(loc, rewriter.getI64Type(), outputW);
|
||||||
|
|
||||||
|
outputSizeIntValues.push_back(outputH);
|
||||||
|
outputSizeIntValues.push_back(outputW);
|
||||||
|
} else {
|
||||||
|
SmallVector<Value, 2> outputSizeTorchInt;
|
||||||
|
if (!getListConstructElements(op.getSize(), outputSizeTorchInt))
|
||||||
|
return rewriter.notifyMatchFailure(
|
||||||
|
op, "unimplemented: the output_size is not constructed from "
|
||||||
|
"ListConstruct");
|
||||||
|
outputSizeIntValues = getTypeConvertedValues(
|
||||||
|
rewriter, loc, getTypeConverter(), outputSizeTorchInt);
|
||||||
|
}
|
||||||
|
int hDimOffset = 2;
|
||||||
|
SmallVector<Value> dims = getTensorSizes(rewriter, loc, input);
|
||||||
|
dims[hDimOffset] = castIntToIndex(rewriter, loc, outputSizeIntValues[0]);
|
||||||
|
dims[hDimOffset + 1] =
|
||||||
|
castIntToIndex(rewriter, loc, outputSizeIntValues[1]);
|
||||||
|
|
||||||
|
Value outTensor = rewriter.create<tensor::EmptyOp>(
|
||||||
|
loc, getAsOpFoldResult(dims), inputType.getElementType());
|
||||||
|
|
||||||
|
AffineMap idMap = rewriter.getMultiDimIdentityMap(inputRank);
|
||||||
|
|
||||||
|
SmallVector<utils::IteratorType> iteratorTypes(
|
||||||
|
inputRank, utils::IteratorType::parallel);
|
||||||
|
|
||||||
|
Value finalRes =
|
||||||
|
rewriter
|
||||||
|
.create<linalg::GenericOp>(
|
||||||
|
loc, outTensor.getType(), ValueRange{}, outTensor,
|
||||||
|
/*indexingMaps=*/idMap,
|
||||||
|
/*iteratorTypes=*/iteratorTypes,
|
||||||
|
[&](OpBuilder &b, Location loc, ValueRange args) {
|
||||||
|
Value outputSizeH = outputSizeIntValues[0];
|
||||||
|
Value outputSizeW = outputSizeIntValues[1];
|
||||||
|
Value inputSizeH = b.create<arith::ConstantOp>(
|
||||||
|
loc, b.getI64IntegerAttr(inputType.getShape()[2]));
|
||||||
|
Value inputSizeW = b.create<arith::ConstantOp>(
|
||||||
|
loc, b.getI64IntegerAttr(inputType.getShape()[3]));
|
||||||
|
Value retVal;
|
||||||
|
if (mode == "nearest") {
|
||||||
|
retVal =
|
||||||
|
NearestInterpolate(b, loc, outputSizeH, outputSizeW,
|
||||||
|
input, inputSizeH, inputSizeW);
|
||||||
|
} else if (mode == "bilinear") {
|
||||||
|
retVal = BilinearInterpolate(b, op, loc, outputSizeH,
|
||||||
|
outputSizeW, input, inputSizeH,
|
||||||
|
inputSizeW);
|
||||||
|
}
|
||||||
|
b.create<linalg::YieldOp>(loc, retVal);
|
||||||
|
})
|
||||||
|
.getResult(0);
|
||||||
|
Type newResultType =
|
||||||
|
getTypeConverter()->convertType(op.getResult().getType());
|
||||||
|
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, finalRes);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
TypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||||
ConversionTarget &target) {
|
ConversionTarget &target) {
|
||||||
|
@ -2766,4 +3101,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality(
|
||||||
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
|
patterns.add<ConvertDequantizePerChannel>(typeConverter, context);
|
||||||
target.addIllegalOp<AtenGridSamplerOp>();
|
target.addIllegalOp<AtenGridSamplerOp>();
|
||||||
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
|
patterns.add<ConvertAtenGridSamplerOp>(typeConverter, context);
|
||||||
|
target.addIllegalOp<Aten__InterpolateSizeListScaleListOp>();
|
||||||
|
patterns.add<ConvertInterpolateOp>(typeConverter, context);
|
||||||
}
|
}
|
||||||
|
|
|
@ -6654,6 +6654,70 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
" return %4 : !torch.list<int>\n"
|
" return %4 : !torch.list<int>\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_shape_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.list<int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>, %arg3: !torch.str, %arg4: !torch.optional<bool>, %arg5: !torch.optional<bool>, %arg6: !torch.bool) -> !torch.list<int> {\n"
|
||||||
|
" %str = torch.constant.str \"AssertionError: Either size or scale_factor must be presented\"\n"
|
||||||
|
" %str_0 = torch.constant.str \"AssertionError: Must specify exactly one of size and scale_factor\"\n"
|
||||||
|
" %none = torch.constant.none\n"
|
||||||
|
" %int0 = torch.constant.int 0\n"
|
||||||
|
" %int1 = torch.constant.int 1\n"
|
||||||
|
" %int2 = torch.constant.int 2\n"
|
||||||
|
" %int3 = torch.constant.int 3\n"
|
||||||
|
" %false = torch.constant.bool false\n"
|
||||||
|
" %true = torch.constant.bool true\n"
|
||||||
|
" %0 = torch.prim.Uninitialized : !torch.list<int>\n"
|
||||||
|
" %1 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %2 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %3 = torch.prim.ListConstruct %1, %2 : (!torch.int, !torch.int) -> !torch.list<int>\n"
|
||||||
|
" %4 = torch.aten.__isnot__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||||
|
" %5:2 = torch.prim.If %4 -> (!torch.bool, !torch.list<int>) {\n"
|
||||||
|
" %7 = torch.prim.unchecked_cast %arg1 : !torch.optional<list<int>> -> !torch.list<int>\n"
|
||||||
|
" %8 = torch.aten.__is__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||||
|
" torch.prim.If %8 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %9 = torch.aten.__getitem__.t %7, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %10 = torch.aten.append.t %3, %9 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" %11 = torch.aten.__getitem__.t %7, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %12 = torch.aten.append.t %3, %11 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" %7 = torch.aten.__isnot__ %arg2, %none : !torch.optional<list<float>>, !torch.none -> !torch.bool\n"
|
||||||
|
" %8:2 = torch.prim.If %7 -> (!torch.bool, !torch.list<int>) {\n"
|
||||||
|
" %9 = torch.prim.unchecked_cast %arg2 : !torch.optional<list<float>> -> !torch.list<float>\n"
|
||||||
|
" %10 = torch.aten.__is__ %arg1, %none : !torch.optional<list<int>>, !torch.none -> !torch.bool\n"
|
||||||
|
" torch.prim.If %10 -> () {\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield\n"
|
||||||
|
" }\n"
|
||||||
|
" %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||||
|
" %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n"
|
||||||
|
" %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n"
|
||||||
|
" %15 = torch.aten.append.t %3, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
|
||||||
|
" %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
|
||||||
|
" %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n"
|
||||||
|
" %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n"
|
||||||
|
" %20 = torch.aten.append.t %3, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
|
||||||
|
" torch.prim.If.yield %true, %3 : !torch.bool, !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.If.yield %false, %0 : !torch.bool, !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" torch.prim.If.yield %8#0, %8#1 : !torch.bool, !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" %6 = torch.prim.If %5#0 -> (!torch.list<int>) {\n"
|
||||||
|
" torch.prim.If.yield %5#1 : !torch.list<int>\n"
|
||||||
|
" } else {\n"
|
||||||
|
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
|
||||||
|
" torch.prim.If.yield %3 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
|
" return %6 : !torch.list<int>\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
|
" func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list<int> {\n"
|
||||||
" %true = torch.constant.bool true\n"
|
" %true = torch.constant.bool true\n"
|
||||||
" %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
|
" %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n"
|
||||||
|
@ -10159,6 +10223,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
|
||||||
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
" return %0#1 : !torch.int\n"
|
" return %0#1 : !torch.int\n"
|
||||||
" }\n"
|
" }\n"
|
||||||
|
" func.func @\"__torch_mlir_dtype_fn.aten.__interpolate.size_list_scale_list\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.optional<list<int>>, %arg2: !torch.optional<list<float>>, %arg3: !torch.str, %arg4: !torch.optional<bool>, %arg5: !torch.optional<bool>, %arg6: !torch.bool) -> !torch.int {\n"
|
||||||
|
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
|
||||||
|
" return %0#1 : !torch.int\n"
|
||||||
|
" }\n"
|
||||||
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>) -> !torch.int {\n"
|
||||||
" %none = torch.constant.none\n"
|
" %none = torch.constant.none\n"
|
||||||
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
|
" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n"
|
||||||
|
|
|
@ -338,6 +338,26 @@ def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation
|
||||||
output = [input[0],input[1],grid[1],grid[2]]
|
output = [input[0],input[1],grid[1],grid[2]]
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def aten〇__interpolate〇size_list_scale_list〡shape(input: List[int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> List[int]:
    # Result shape of a resize over a 4-D input: the first two extents are
    # copied from `input`, the last two come either straight from `size` or
    # from `scale_factor` applied to input[2] / input[3] (truncated by int()).
    # `mode`, `align_corners`, `recompute_scale_factor` and `antialias` do not
    # influence the shape.
    result = [input[0], input[1]]

    if size is not None:
        # Explicit target extents: the caller must not also pass scales.
        assert scale_factor is None, "Must specify exactly one of size and scale_factor"
        result.append(size[0])
        result.append(size[1])
        return result

    if scale_factor is not None:
        assert size is None, "Must specify exactly one of size and scale_factor"
        result.append(int(scale_factor[0] * input[2]))
        result.append(int(scale_factor[1] * input[3]))
        return result

    # Neither argument was supplied.
    assert 0, "Either size or scale_factor must be presented"
    return result
|
||||||
|
|
||||||
|
|
||||||
def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
|
def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
|
||||||
# Obtained through trial and error on a few examples in PyTorch:
|
# Obtained through trial and error on a few examples in PyTorch:
|
||||||
assert start < len(a), "start out of bounds"
|
assert start < len(a), "start out of bounds"
|
||||||
|
@ -2330,6 +2350,10 @@ def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dty
|
||||||
grid_rank, grid_dtype = input_rank_dtype
|
grid_rank, grid_dtype = input_rank_dtype
|
||||||
return input_dtype
|
return input_dtype
|
||||||
|
|
||||||
|
def aten〇__interpolate〇size_list_scale_list〡dtype(input_rank_dtype: Tuple[int, int], size: Optional[List[int]] = None, scale_factor: Optional[List[float]] = None, mode: str = "nearest", align_corners: Optional[bool] = None, recompute_scale_factor: Optional[bool] = None, antialias: bool = False) -> int:
    # `input_rank_dtype` is a (rank, dtype) pair; the result dtype is the
    # input's dtype, independent of every other argument.
    return input_rank_dtype[1]
|
||||||
|
|
||||||
@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
|
@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1),
|
||||||
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
|
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]),
|
||||||
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
|
ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]),
|
||||||
|
|
|
@ -634,6 +634,9 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
|
||||||
emit_with_mutating_variants(
|
emit_with_mutating_variants(
|
||||||
"aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)"
|
"aten::masked_scatter : (Tensor, Tensor, Tensor) -> (Tensor)"
|
||||||
)
|
)
|
||||||
|
emit(
|
||||||
|
"aten::__interpolate.size_list_scale_list : (Tensor, int[]?, float[]?, str, bool?, bool?, bool) -> (Tensor)"
|
||||||
|
)
|
||||||
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::adaptive_avg_pool1d : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||||
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
emit("aten::_adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)")
|
||||||
|
|
|
@ -2006,3 +2006,24 @@ func.func @test_sce_mean_3d_log_prob(%arg0: !torch.vtensor<[3,5,2],f32>, %arg1:
|
||||||
%0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>)
|
%0:2 = torch.operator "onnx.SoftmaxCrossEntropyLoss"(%arg0, %arg1) {torch.onnx.reduction = "mean"} : (!torch.vtensor<[3,5,2],f32>, !torch.vtensor<[3,2],si64>) -> (!torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>)
|
||||||
return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>
|
return %0#0, %0#1 : !torch.vtensor<[],f32>, !torch.vtensor<[3,5,2],f32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// onnx.Resize given explicit `sizes` in nearest mode is expected to become a
// single torch.aten.__interpolate.size_list_scale_list op. Operands are
// matched through captures rather than printed SSA names so the test does not
// depend on value numbering; the two `none` operands and the two `false`
// operands must each be the same value.
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %[[SIZE:[^,]+]], %[[NONE:[^,]+]], %[[MODE:[^,]+]], %[[FALSE:[^,]+]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
  %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.coordinate_transformation_mode = "asymmetric", torch.onnx.cubic_coeff_a = -7.500000e-01 : f32, torch.onnx.mode = "nearest", torch.onnx.nearest_mode = "floor"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?,?],f32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// onnx.Resize given explicit `sizes` in linear mode is expected to become a
// single torch.aten.__interpolate.size_list_scale_list op. Operands are
// matched through captures rather than printed SSA names so the test does not
// depend on value numbering; the two `none` operands and the two `false`
// operands must each be the same value.
// CHECK-LABEL: func.func @test_resize_sizes_linear
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  // CHECK: torch.aten.__interpolate.size_list_scale_list %arg0, %[[SIZE:[^,]+]], %[[NONE:[^,]+]], %[[MODE:[^,]+]], %[[FALSE:[^,]+]], %[[NONE]], %[[FALSE]] : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
  %0 = torch.operator "onnx.Resize"(%arg0, %none, %none, %arg1) {torch.onnx.mode = "linear"} : (!torch.vtensor<[1,1,2,4],f32>, !torch.none, !torch.none, !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32>
  return %0 : !torch.vtensor<[?,?,?,?],f32>
}
|
||||||
|
|
|
@ -0,0 +1,142 @@
|
||||||
|
// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s
|
||||||
|
|
||||||
|
// Bilinear interpolate lowered to linalg. The directives below expect one
// linalg.generic whose body, for each of the two spatial dims, forms
// (index + 0.5) / scale - 0.5, clamps it to [0, in_dim - 1.001], takes
// floor(coord) and floor(1 + coord) as the two neighbours, reads the four
// surrounding input elements and blends them with distance-based weights.
// CHECK-LABEL: func.func @test_resize_sizes_linear
func.func @test_resize_sizes_linear(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[generic:.*]] = linalg.generic
  // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
  // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
  // CHECK: %[[cst:.*]] = arith.constant 1.001000e+00 : f32
  // CHECK: %[[cst_4:.*]] = arith.constant 1.000000e+00 : f32
  // CHECK: %[[cst_5:.*]] = arith.constant 5.000000e-01 : f32
  // CHECK: %[[cst_6:.*]] = arith.constant 0.000000e+00 : f32
  // CHECK: %[[x13:.*]] = linalg.index 2 : index
  // CHECK: %[[x14:.*]] = linalg.index 3 : index
  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
  // CHECK: %[[x16:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
  // CHECK: %[[x17:.*]] = arith.divf %[[x16]], %[[x15]] : f32
  // CHECK: %[[x18:.*]] = arith.index_cast %[[x13]] : index to i64
  // CHECK: %[[x19:.*]] = arith.sitofp %[[x18]] : i64 to f32
  // CHECK: %[[x20:.*]] = arith.addf %[[x19]], %[[cst_5]] : f32
  // CHECK: %[[x21:.*]] = arith.divf %[[x20]], %[[x17]] : f32
  // CHECK: %[[x22:.*]] = arith.subf %[[x21]], %[[cst_5]] : f32
  // CHECK: %[[x23:.*]] = arith.maximumf %[[x22]], %[[cst_6]] : f32
  // CHECK: %[[x24:.*]] = arith.subf %[[x15]], %[[cst]] : f32
  // CHECK: %[[x25:.*]] = arith.minimumf %[[x23]], %[[x24]] : f32
  // CHECK: %[[x26:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
  // CHECK: %[[x27:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
  // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x26]] : f32
  // CHECK: %[[x29:.*]] = arith.index_cast %[[x14]] : index to i64
  // CHECK: %[[x30:.*]] = arith.sitofp %[[x29]] : i64 to f32
  // CHECK: %[[x31:.*]] = arith.addf %[[x30]], %[[cst_5]] : f32
  // CHECK: %[[x32:.*]] = arith.divf %[[x31]], %[[x28]] : f32
  // CHECK: %[[x33:.*]] = arith.subf %[[x32]], %[[cst_5]] : f32
  // CHECK: %[[x34:.*]] = arith.maximumf %[[x33]], %[[cst_6]] : f32
  // CHECK: %[[x35:.*]] = arith.subf %[[x26]], %[[cst]] : f32
  // CHECK: %[[x36:.*]] = arith.minimumf %[[x34]], %[[x35]] : f32
  // CHECK: %[[x37:.*]] = math.floor %[[x25]] : f32
  // CHECK: %[[x38:.*]] = arith.addf %[[cst_4]], %[[x25]] : f32
  // CHECK: %[[x39:.*]] = math.floor %[[x38]] : f32
  // CHECK: %[[x40:.*]] = math.floor %[[x36]] : f32
  // CHECK: %[[x41:.*]] = arith.addf %[[cst_4]], %[[x36]] : f32
  // CHECK: %[[x42:.*]] = math.floor %[[x41]] : f32
  // CHECK: %[[x43:.*]] = linalg.index 0 : index
  // CHECK: %[[x44:.*]] = linalg.index 1 : index
  // CHECK: %[[x45:.*]] = linalg.index 2 : index
  // CHECK: %[[x46:.*]] = linalg.index 3 : index
  // CHECK: %[[x47:.*]] = arith.fptosi %[[x37]] : f32 to i64
  // CHECK: %[[x48:.*]] = arith.index_cast %[[x47]] : i64 to index
  // CHECK: %[[x49:.*]] = arith.fptosi %[[x40]] : f32 to i64
  // CHECK: %[[x50:.*]] = arith.index_cast %[[x49]] : i64 to index
  // CHECK: %[[x51:.*]] = arith.fptosi %[[x39]] : f32 to i64
  // CHECK: %[[x52:.*]] = arith.index_cast %[[x51]] : i64 to index
  // CHECK: %[[x53:.*]] = arith.fptosi %[[x42]] : f32 to i64
  // CHECK: %[[x54:.*]] = arith.index_cast %[[x53]] : i64 to index
  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x43]], %[[x44]], %[[x48]], %[[x50]]] : tensor<1x1x2x4xf32>
  // CHECK: %[[extracted_7:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x48]], %[[x54]]] : tensor<1x1x2x4xf32>
  // CHECK: %[[extracted_8:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x50]]] : tensor<1x1x2x4xf32>
  // CHECK: %[[extracted_9:.*]] = tensor.extract %[[x0]][%[[x43]], %[[x44]], %[[x52]], %[[x54]]] : tensor<1x1x2x4xf32>
  // CHECK: %[[x55:.*]] = arith.subf %[[x42]], %[[x36]] : f32
  // CHECK: %[[x56:.*]] = arith.subf %[[x42]], %[[x40]] : f32
  // CHECK: %[[x57:.*]] = arith.divf %[[x55]], %[[x56]] : f32
  // CHECK: %[[x58:.*]] = arith.mulf %[[x57]], %[[extracted]] : f32
  // CHECK: %[[x59:.*]] = arith.subf %[[x36]], %[[x40]] : f32
  // CHECK: %[[x60:.*]] = arith.divf %[[x59]], %[[x56]] : f32
  // CHECK: %[[x61:.*]] = arith.mulf %[[x60]], %[[extracted_7]] : f32
  // CHECK: %[[x62:.*]] = arith.addf %[[x58]], %[[x61]] : f32
  // CHECK: %[[x63:.*]] = arith.mulf %[[x57]], %[[extracted_8]] : f32
  // CHECK: %[[x64:.*]] = arith.mulf %[[x60]], %[[extracted_9]] : f32
  // CHECK: %[[x65:.*]] = arith.addf %[[x63]], %[[x64]] : f32
  // CHECK: %[[x66:.*]] = arith.subf %[[x39]], %[[x25]] : f32
  // CHECK: %[[x67:.*]] = arith.subf %[[x39]], %[[x37]] : f32
  // CHECK: %[[x68:.*]] = arith.divf %[[x66]], %[[x67]] : f32
  // CHECK: %[[x69:.*]] = arith.mulf %[[x68]], %[[x62]] : f32
  // CHECK: %[[x70:.*]] = arith.subf %[[x25]], %[[x37]] : f32
  // CHECK: %[[x71:.*]] = arith.divf %[[x70]], %[[x67]] : f32
  // CHECK: %[[x72:.*]] = arith.mulf %[[x71]], %[[x65]] : f32
  // CHECK: %[[x73:.*]] = arith.addf %[[x69]], %[[x72]] : f32
  %none = torch.constant.none
  %none_0 = torch.constant.none
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %true = torch.constant.bool true
  %str = torch.constant.str "bilinear"
  %int2 = torch.constant.int 2
  %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
  %int3 = torch.constant.int 3
  %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
  %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
  %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
  return %5 : !torch.vtensor<[?,?,?,?],f32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// Nearest interpolate lowered to linalg. The directives below expect one
// linalg.generic whose body divides each spatial output index by the
// output/input extent ratio, floors the result, and yields the input element
// read at those source indices. The label anchors the directives to this
// function's output.
// CHECK-LABEL: func.func @test_resize_sizes_nearest
func.func @test_resize_sizes_nearest(%arg0: !torch.vtensor<[1,1,2,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[GENERIC:.*]] = linalg.generic
  // CHECK: %[[c2_i64:.*]] = arith.constant 2 : i64
  // CHECK: %[[c4_i64:.*]] = arith.constant 4 : i64
  // CHECK: %[[x13:.*]] = linalg.index 2 : index
  // CHECK: %[[x14:.*]] = linalg.index 3 : index
  // CHECK: %[[x15:.*]] = arith.sitofp %[[c2_i64]] : i64 to f32
  // CHECK: %[[x16:.*]] = arith.sitofp %[[c4_i64]] : i64 to f32
  // CHECK: %[[x19:.*]] = arith.sitofp %[[x6:.*]] : i64 to f32
  // CHECK: %[[x20:.*]] = arith.sitofp %[[x7:.*]] : i64 to f32
  // CHECK: %[[x21:.*]] = arith.divf %[[x19]], %[[x15]] : f32
  // CHECK: %[[x22:.*]] = arith.divf %[[x20]], %[[x16]] : f32
  // CHECK: %[[x23:.*]] = arith.index_cast %[[x13]] : index to i64
  // CHECK: %[[x24:.*]] = arith.sitofp %[[x23]] : i64 to f32
  // CHECK: %[[x25:.*]] = arith.divf %[[x24]], %[[x21]] : f32
  // CHECK: %[[x26:.*]] = arith.index_cast %[[x14]] : index to i64
  // CHECK: %[[x27:.*]] = arith.sitofp %[[x26]] : i64 to f32
  // CHECK: %[[x28:.*]] = arith.divf %[[x27]], %[[x22]] : f32
  // CHECK: %[[x29:.*]] = math.floor %[[x25]] : f32
  // CHECK: %[[x30:.*]] = math.floor %[[x28]] : f32
  // CHECK: %[[x31:.*]] = arith.fptosi %[[x29]] : f32 to i64
  // CHECK: %[[x32:.*]] = arith.index_cast %[[x31]] : i64 to index
  // CHECK: %[[x33:.*]] = arith.fptosi %[[x30]] : f32 to i64
  // CHECK: %[[x34:.*]] = arith.index_cast %[[x33]] : i64 to index
  // CHECK: %[[x35:.*]] = linalg.index 0 : index
  // CHECK: %[[x36:.*]] = linalg.index 1 : index
  // CHECK: %[[x37:.*]] = linalg.index 2 : index
  // CHECK: %[[x38:.*]] = linalg.index 3 : index
  // CHECK: %[[extracted:.*]] = tensor.extract %[[x0:.*]][%[[x35]], %[[x36]], %[[x32]], %[[x34]]] : tensor<1x1x2x4xf32>
  // CHECK: linalg.yield %[[extracted]] : f32
  %none = torch.constant.none
  %none_0 = torch.constant.none
  %int0 = torch.constant.int 0
  %false = torch.constant.bool false
  %true = torch.constant.bool true
  %str = torch.constant.str "nearest"
  %int2 = torch.constant.int 2
  %0 = torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %1 = torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int
  %int3 = torch.constant.int 3
  %2 = torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64>
  %3 = torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int
  %4 = torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list<int>
  %5 = torch.aten.__interpolate.size_list_scale_list %arg0, %4, %none_0, %str, %false, %none_0, %false : !torch.vtensor<[1,1,2,4],f32>, !torch.list<int>, !torch.none, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[?,?,?,?],f32>
  return %5 : !torch.vtensor<[?,?,?,?],f32>
}
|
Loading…
Reference in New Issue