mirror of https://github.com/llvm/torch-mlir
[ONNX] Add basic support for RoiAlign (#3493)
This adds an onnx->torch conversion for onnx.RoiAlign into torchvision.roi_align or torchvision.roi_pool, and adds those two torchvision ops to torch-mlir.
pull/3495/head
parent 02340408b7
commit e346c911f7
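
For orientation, the lowering added in this commit behaves roughly like the eager-PyTorch sketch below (illustrative only, not part of the diff; the helper name and example tensors are invented). ONNX RoiAlign carries a separate batch_indices input, which is folded into the first column of a (num_rois, 5) boxes tensor; mode="avg" maps to torchvision's roi_align and mode="max" maps to roi_pool.

# Illustrative sketch only -- not part of this commit.
import torch
import torchvision

def onnx_style_roi_align(x, rois, batch_indices, mode="avg",
                         output_size=(5, 5), spatial_scale=1.0,
                         sampling_ratio=0, coord_tf_mode="half_pixel"):
    # Fold the ONNX batch_indices input into column 0 of an (N, 5) boxes tensor.
    boxes = torch.cat([batch_indices.unsqueeze(1).to(rois.dtype), rois], dim=1)
    if mode == "avg":
        # ONNX sampling_ratio == 0 ("adaptive") is spelled -1 in torchvision.
        return torchvision.ops.roi_align(
            x, boxes, output_size, spatial_scale=spatial_scale,
            sampling_ratio=-1 if sampling_ratio == 0 else sampling_ratio,
            aligned=(coord_tf_mode == "half_pixel"))
    # mode == "max" is handled with roi_pool (only the pooled output is used).
    return torchvision.ops.roi_pool(x, boxes, output_size,
                                    spatial_scale=spatial_scale)

x = torch.randn(6, 2, 100, 100)
rois = torch.tensor([[10.0, 10.0, 50.0, 50.0]] * 30)
batch_indices = torch.randint(0, 6, (30,))
print(onnx_style_roi_align(x, rois, batch_indices).shape)  # torch.Size([30, 2, 5, 5])
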
@@ -16660,3 +16660,60 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
  }];
}

def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchTensorType:$rois,
    Torch_FloatType:$spatial_scale,
    Torch_IntType:$pooled_height,
    Torch_IntType:$pooled_width,
    Torch_IntType:$sampling_ratio,
    Torch_BoolType:$aligned
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult TorchvisionRoiAlignOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 7, 1);
    }
    void TorchvisionRoiAlignOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 7, 1);
    }
  }];
}

def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [
    AllowsTypeRefinement,
    HasValueSemantics,
    ReadOnly
  ]> {
  let summary = "Generated op for `torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)`";
  let arguments = (ins
    AnyTorchTensorType:$input,
    AnyTorchTensorType:$rois,
    Torch_FloatType:$spatial_scale,
    Torch_IntType:$pooled_height,
    Torch_IntType:$pooled_width
  );
  let results = (outs
    AnyTorchOptionalTensorType:$result0,
    AnyTorchOptionalTensorType:$result1
  );
  let hasCustomAssemblyFormat = 1;
  let extraClassDefinition = [{
    ParseResult TorchvisionRoiPoolOp::parse(OpAsmParser &parser, OperationState &result) {
      return parseDefaultTorchOp(parser, result, 5, 2);
    }
    void TorchvisionRoiPoolOp::print(OpAsmPrinter &printer) {
      printDefaultTorchOp(printer, *this, 5, 2);
    }
  }];
}
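
The summaries above name the native torchvision operator schemas that these generated ops wrap. Assuming torchvision is installed, the same schemas are reachable from eager PyTorch through torch.ops, which is a quick way to confirm the operand order and result arity (7 operands/1 result and 5 operands/2 results). A sketch with invented tensors:

# Sanity-check sketch (not part of the commit): call the native schemas that
# the generated ops above wrap, in the operand order given by their summaries.
import torch
import torchvision  # importing torchvision registers the torchvision:: operators

x = torch.randn(6, 2, 100, 100)
rois5 = torch.tensor([[0.0, 10.0, 10.0, 50.0, 50.0]] * 30)  # (batch_idx, x1, y1, x2, y2)

# torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)
aligned_out = torch.ops.torchvision.roi_align(x, rois5, 1.0, 5, 5, -1, True)

# torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)
pooled, argmax = torch.ops.torchvision.roi_pool(x, rois5, 1.0, 5, 5)

print(aligned_out.shape, pooled.shape)  # both torch.Size([30, 2, 5, 5])
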
@@ -2953,6 +2953,104 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
            /*Torch_BoolType:$antialias*/ cstFalse);
        return success();
      });
  patterns.onOp(
      "RoiAlign", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // operands = input, rois, batch_indices
        SmallVector<Value> operands;
        std::string coordTfMode, mode;
        int64_t outHInt, outWInt, samplingRatioInt;
        float spatialScaleFloat;
        Torch::ValueTensorType resultType;
        if (binder.tensorOperands(operands, 3) ||
            binder.customOpNameStringAttr(
                coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
            binder.customOpNameStringAttr(mode, "mode", "avg") ||
            binder.s64IntegerAttr(outHInt, "output_height", 1) ||
            binder.s64IntegerAttr(outWInt, "output_width", 1) ||
            binder.s64IntegerAttr(samplingRatioInt, "sampling_ratio", 0) ||
            binder.f32FloatAttr(spatialScaleFloat, "spatial_scale", 1.0f) ||
            binder.tensorResultType(resultType))
          return failure();
        Value input = operands[0];
        Value rois = operands[1];
        Value batchIndices = operands[2];

        // The torchvision roi_pool op does not support these features:
        if (mode == "max" &&
            (coordTfMode != "half_pixel" || samplingRatioInt != 0))
          return rewriter.notifyMatchFailure(
              binder.op, "unsupported: roi max pooling without default "
                         "coordTfMode and sampling_ratio");

        Location loc = binder.getLoc();
        // Concatenate batchIndices to rois to get a num_rois x 5 rois tensor.
        // The batchIndices tensor is an int64 tensor and needs to be converted
        // to the (float) rois dtype before concatenation.
        auto roisType = dyn_cast<Torch::ValueTensorType>(rois.getType());
        if (!roisType || !roisType.hasSizes())
          return failure();
        Value cstDim = rewriter.create<Torch::ConstantIntOp>(
            binder.getLoc(), rewriter.getI64IntegerAttr(1));
        FailureOr<Value> unsqueezeIndices =
            Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim);
        if (failed(unsqueezeIndices))
          return failure();
        batchIndices = unsqueezeIndices.value();
        auto batchIndicesType =
            cast<Torch::ValueTensorType>(batchIndices.getType());
        Value dTypeInt =
            Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype());
        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
        Value cstFalse =
            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
        Value newBatchIndices = rewriter.create<Torch::AtenToDtypeOp>(
            loc,
            batchIndicesType.getWithSizesAndDtype(
                batchIndicesType.getOptionalSizes(),
                roisType.getOptionalDtype()),
            batchIndices, dTypeInt, cstFalse, cstFalse, none);
        SmallVector<int64_t> roiSizes(roisType.getSizes());
        roiSizes.back() = 5;
        auto catType = rewriter.getType<Torch::ValueTensorType>(
            roiSizes, roisType.getDtype());
        Type listElemType =
            roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                          /*optionalDtype=*/nullptr);
        Type listType = Torch::ListType::get(listElemType);
        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
            binder.op->getLoc(), listType, ValueRange{newBatchIndices, rois});
        Value newRois =
            rewriter.create<Torch::AtenCatOp>(loc, catType, tensorList, cstDim);

        // Make constants from the attributes.
        Value cstSpatialScale = rewriter.create<Torch::ConstantFloatOp>(
            loc, rewriter.getF64FloatAttr(spatialScaleFloat));
        Value pooledHeight = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(outHInt));
        Value pooledWidth = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(outWInt));
        // Map 0 to -1 for consistency with the default PyTorch sampling_ratio
        // value.
        samplingRatioInt = (samplingRatioInt == 0) ? -1 : samplingRatioInt;
        Value samplingRatio = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(samplingRatioInt));
        bool aligned = coordTfMode == "half_pixel";
        Value cstAligned = rewriter.create<Torch::ConstantBoolOp>(loc, aligned);

        if (mode == "avg") {
          rewriter.replaceOpWithNewOp<Torch::TorchvisionRoiAlignOp>(
              binder.op, resultType, input, newRois, cstSpatialScale,
              pooledHeight, pooledWidth, samplingRatio, cstAligned);
          return success();
        }
        // mode == "max"
        auto indicesType = resultType.getWithSizesAndDtype(
            resultType.getOptionalSizes(), batchIndicesType.getDtype());
        auto roiPool = rewriter.create<Torch::TorchvisionRoiPoolOp>(
            loc, TypeRange{resultType, indicesType}, input, newRois,
            cstSpatialScale, pooledHeight, pooledWidth);
        rewriter.replaceOp(binder.op, roiPool.getResult(0));
        return success();
      });
  patterns.onOp(
      "SpaceToDepth", 1,
      [](OpBinder binder, ConversionPatternRewriter &rewriter) {
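
Stripped of the rewriter plumbing, the preprocessing performed by this pattern amounts to the tensor-level steps below (a sketch with invented names, not the actual implementation):

# Tensor-level restatement of the preprocessing in the RoiAlign pattern above
# (illustrative only).
import torch

def prepare_roi_operands(rois, batch_indices, sampling_ratio, coord_tf_mode):
    # batch_indices is int64; unsqueeze to (N, 1) and cast to the rois dtype so
    # it can be concatenated as the leading column of an (N, 5) rois tensor.
    idx = batch_indices.unsqueeze(1).to(rois.dtype)
    new_rois = torch.cat([idx, rois], dim=1)
    # ONNX uses sampling_ratio == 0 for "adaptive"; PyTorch/torchvision use -1.
    sampling_ratio = -1 if sampling_ratio == 0 else sampling_ratio
    # "half_pixel" corresponds to aligned=True; "output_half_pixel" (and any
    # other mode) to aligned=False.
    aligned = coord_tf_mode == "half_pixel"
    return new_rois, sampling_ratio, aligned

new_rois, sr, aligned = prepare_roi_operands(
    torch.rand(30, 4) * 100, torch.randint(0, 6, (30,)), 0, "output_half_pixel")
print(new_rois.shape, sr, aligned)  # torch.Size([30, 5]) -1 False
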
@@ -6256,6 +6256,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
"    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
"    return %0 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.torchvision.roi_align\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.list<int> {\n"
"    %int0 = torch.constant.int 0\n"
"    %int1 = torch.constant.int 1\n"
"    %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
"    return %2 : !torch.list<int>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.torchvision.roi_align\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n"
"    %int1 = torch.constant.int 1\n"
"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
"    return %0 : !torch.int\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.torchvision.roi_pool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
"    %int0 = torch.constant.int 0\n"
"    %int1 = torch.constant.int 1\n"
"    %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
"    %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
"    %3 = torch.prim.TupleConstruct %2, %2 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
"    return %3 : !torch.tuple<list<int>, list<int>>\n"
"  }\n"
"  func.func @\"__torch_mlir_dtype_fn.torchvision.roi_pool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
"    %int4 = torch.constant.int 4\n"
"    %int1 = torch.constant.int 1\n"
"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
"    %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
"    return %1 : !torch.tuple<int, int>\n"
"  }\n"
"  func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
"    %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
"    %true = torch.constant.bool true\n"
@@ -8,6 +8,7 @@ import argparse
import os

import torch
import torchvision
from torch import device
import torch.jit._shape_functions as upstream_shape_functions
@@ -85,6 +86,20 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
    return upstream_shape_functions.unary(self)

def torchvision〇roi_align〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> List[int]:
    return [rois[0], input[1], pooled_height, pooled_width]

def torchvision〇roi_align〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> int:
    return input_rank_dtype[1]

def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[List[int], List[int]]:
    output = [rois[0], input[1], pooled_height, pooled_width]
    return (output, output)

def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]:
    return (input_rank_dtype[1], torch.int64)

@check_shape_function([
    Invocation(TensorOfShape(2, 3, 4)), # Basic case.
    Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.
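
As a quick worked example (not part of the commit), applying these shape functions to the shapes used in the lit tests later in this diff:

# Worked example: the shape functions above evaluated on the lit-test shapes.
input_shape = [6, 2, 100, 100]        # (N, C, H, W)
rois_shape = [30, 5]                  # after prepending the batch-index column
pooled_height, pooled_width = 5, 5

roi_align_out = [rois_shape[0], input_shape[1], pooled_height, pooled_width]
assert roi_align_out == [30, 2, 5, 5]  # matches !torch.vtensor<[30,2,5,5],f32>

# roi_pool returns the same shape twice: pooled values and argmax indices.
roi_pool_out = (roi_align_out, roi_align_out)
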
@@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
        traits=["HasValueSemantics"],
    )

    emit(
        "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)"
    )
    emit(
        "torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)"
    )


def dump_registered_ops(outfile: TextIO, registry: Registry):
    for _, v in sorted(registry.by_unique_key.items()):
@@ -1173,6 +1180,8 @@ def _maybe_import_op_extensions(args: argparse.Namespace):

def main(args: argparse.Namespace):
    _maybe_import_op_extensions(args)
    import torchvision

    registry = Registry.load()
    if args.debug_registry_dump:
        with open(args.debug_registry_dump, "w") as debug_registry_dump:
@@ -2207,6 +2207,37 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve

// -----

// CHECK-LABEL: @test_roialign_avg
func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[Dim:.*]] = torch.constant.int 1
  // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]]
  // CHECK: %[[cst6:.*]] = torch.constant.int 6
  // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]]
  // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1
  // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]]
  // CHECK: %[[Align:.*]] = torch.torchvision.roi_align %arg0, %[[Cat]]
  %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "output_half_pixel", torch.onnx.mode = "avg", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32>
  return %0 : !torch.vtensor<[30,2,5,5],f32>
}

// -----

// CHECK-LABEL: @test_roialign_max
func.func @test_roialign_max(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[Dim:.*]] = torch.constant.int 1
  // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]]
  // CHECK: %[[cst6:.*]] = torch.constant.int 6
  // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]]
  // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1
  // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]]
  // CHECK: %[[Pool:.*]], %[[Indices:.*]] = torch.torchvision.roi_pool %arg0, %[[Cat]]
  // CHECK: return %[[Pool]]
  %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "max", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32>
  return %0 : !torch.vtensor<[30,2,5,5],f32>
}

// -----

// CHECK-LABEL: @test_spacetodepth_example
func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  // CHECK: %[[C0:.*]] = torch.constant.int 0