[ONNX] Add basic support for RoiAlign (#3493)

This adds an onnx->torch conversion for onnx.RoiAlign into
torchvision.roi_align or torchvision.roi_pool, and adds those two
torchvision ops to torch-mlir.
pull/3495/head
zjgarvey 2024-06-25 11:02:45 -05:00 committed by GitHub
parent 02340408b7
commit e346c911f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 239 additions and 0 deletions

View File

@ -16660,3 +16660,60 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
}]; }];
} }
// ODS definition of the `torchvision.roi_align` op in the Torch dialect.
// It takes seven operands (two tensors followed by float, int, int, int and
// bool scalars) and produces a single result; the custom parse/print hooks
// below forward those same counts (7 operands, 1 result) to the default Torch
// op assembly helpers.
// NOTE(review): the summary says "Generated op", so this def is presumably
// emitted by the ODS generator and should be regenerated rather than edited by
// hand -- confirm.
def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$rois,
Torch_FloatType:$spatial_scale,
Torch_IntType:$pooled_height,
Torch_IntType:$pooled_width,
Torch_IntType:$sampling_ratio,
Torch_BoolType:$aligned
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult TorchvisionRoiAlignOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 7, 1);
}
void TorchvisionRoiAlignOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 7, 1);
}
}];
}
// ODS definition of the `torchvision.roi_pool` op in the Torch dialect.
// It takes five operands (two tensors followed by float, int and int scalars)
// and produces two results; the custom parse/print hooks below forward those
// same counts (5 operands, 2 results) to the default Torch op assembly helpers.
// NOTE(review): the summary says "Generated op", so this def is presumably
// emitted by the ODS generator and should be regenerated rather than edited by
// hand -- confirm.
def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)`";
let arguments = (ins
AnyTorchTensorType:$input,
AnyTorchTensorType:$rois,
Torch_FloatType:$spatial_scale,
Torch_IntType:$pooled_height,
Torch_IntType:$pooled_width
);
let results = (outs
AnyTorchOptionalTensorType:$result0,
AnyTorchOptionalTensorType:$result1
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult TorchvisionRoiPoolOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 2);
}
void TorchvisionRoiPoolOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 2);
}
}];
}

View File

@ -2953,6 +2953,104 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
/*Torch_BoolType:$antialias*/ cstFalse); /*Torch_BoolType:$antialias*/ cstFalse);
return success(); return success();
}); });
patterns.onOp(
    "RoiAlign", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
      // Lowers onnx.RoiAlign to torchvision.roi_align (mode "avg") or to
      // torchvision.roi_pool (mode "max"). The separate ONNX `batch_indices`
      // operand is cast to the rois dtype and concatenated in front of `rois`
      // so that a single [num_rois, 5] tensor is passed to the torchvision op.
      //
      // operands = input, rois, batch_indices
      SmallVector<Value> operands;
      std::string coordTfMode, mode;
      int64_t outHInt = 1, outWInt = 1, samplingRatioInt = 0;
      float spatialScaleFloat = 1.0f;
      Torch::ValueTensorType resultType;
      if (binder.tensorOperands(operands, 3) ||
          binder.customOpNameStringAttr(
              coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
          binder.customOpNameStringAttr(mode, "mode", "avg") ||
          binder.s64IntegerAttr(outHInt, "output_height", 1) ||
          binder.s64IntegerAttr(outWInt, "output_width", 1) ||
          binder.s64IntegerAttr(samplingRatioInt, "sampling_ratio", 0) ||
          binder.f32FloatAttr(spatialScaleFloat, "spatial_scale", 1.0f) ||
          binder.tensorResultType(resultType))
        return failure();
      Value input = operands[0];
      Value rois = operands[1];
      Value batchIndices = operands[2];

      // Reject unknown modes up front; otherwise anything that is not "avg"
      // would silently be lowered as if it were "max".
      if (mode != "avg" && mode != "max")
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported: mode must be either 'avg' or 'max'");
      // the torchvision roi_pool op does not support these features:
      if (mode == "max" &&
          (coordTfMode != "half_pixel" || samplingRatioInt != 0))
        return rewriter.notifyMatchFailure(
            binder.op, "unsupported: roi max pooling without default "
                       "coordTfMode and sampling_ratio");

      // The rois type drives both the cast of batch_indices and the type of
      // the concatenated tensor, so its sizes and dtype must be known. The
      // rank check also guards the `roiSizes.back()` access below.
      auto roisType = dyn_cast<Torch::ValueTensorType>(rois.getType());
      if (!roisType || !roisType.hasSizes() || !roisType.hasDtype())
        return rewriter.notifyMatchFailure(
            binder.op, "expected rois to have known sizes and dtype");
      if (roisType.getSizes().size() != 2)
        return rewriter.notifyMatchFailure(
            binder.op, "expected rois to be a rank-2 tensor");

      Location loc = binder.getLoc();
      // concatenate the batchIndices to the rois to get rois as a num_roisx5
      // tensor. The batchIndices tensor is an int64 tensor, and needs to be
      // converted to float before concatenation.
      Value cstDim = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(1));
      FailureOr<Value> unsqueezeIndices =
          Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim);
      if (failed(unsqueezeIndices))
        return rewriter.notifyMatchFailure(
            binder.op, "failed to unsqueeze batch_indices");
      batchIndices = unsqueezeIndices.value();
      auto batchIndicesType =
          cast<Torch::ValueTensorType>(batchIndices.getType());
      Value dTypeInt =
          Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype());
      Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
      Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
      Value newBatchIndices = rewriter.create<Torch::AtenToDtypeOp>(
          loc,
          batchIndicesType.getWithSizesAndDtype(
              batchIndicesType.getOptionalSizes(),
              roisType.getOptionalDtype()),
          batchIndices, dTypeInt, cstFalse, cstFalse, none);

      // Same sizes as rois, except that the last dim becomes 5 once the
      // batch-index column has been prepended.
      SmallVector<int64_t> roiSizes(roisType.getSizes());
      roiSizes.back() = 5;
      auto catType = rewriter.getType<Torch::ValueTensorType>(
          roiSizes, roisType.getDtype());
      Type listElemType =
          roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
                                        /*optionalDtype=*/nullptr);
      Type listType = Torch::ListType::get(listElemType);
      Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
          loc, listType, ValueRange{newBatchIndices, rois});
      Value newRois =
          rewriter.create<Torch::AtenCatOp>(loc, catType, tensorList, cstDim);

      // make constants from attributes
      Value cstSpatialScale = rewriter.create<Torch::ConstantFloatOp>(
          loc, rewriter.getF64FloatAttr(spatialScaleFloat));
      Value pooledHeight = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(outHInt));
      Value pooledWidth = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(outWInt));
      // this is for consistency with the default pytorch sampling ratio value
      samplingRatioInt = (samplingRatioInt == 0) ? -1 : samplingRatioInt;
      Value samplingRatio = rewriter.create<Torch::ConstantIntOp>(
          loc, rewriter.getI64IntegerAttr(samplingRatioInt));
      // The torchvision `aligned` flag is set only for the "half_pixel"
      // coordinate transformation mode.
      bool aligned = coordTfMode == "half_pixel";
      Value cstAligned = rewriter.create<Torch::ConstantBoolOp>(loc, aligned);

      if (mode == "avg") {
        rewriter.replaceOpWithNewOp<Torch::TorchvisionRoiAlignOp>(
            binder.op, resultType, input, newRois, cstSpatialScale,
            pooledHeight, pooledWidth, samplingRatio, cstAligned);
        return success();
      }

      // mode == "max": roi_pool has two results; the second one reuses the
      // result sizes with the batch_indices dtype and is not used to replace
      // the ONNX op.
      auto indicesType = resultType.getWithSizesAndDtype(
          resultType.getOptionalSizes(), batchIndicesType.getDtype());
      auto roiPool = rewriter.create<Torch::TorchvisionRoiPoolOp>(
          loc, TypeRange{resultType, indicesType}, input, newRois,
          cstSpatialScale, pooledHeight, pooledWidth);
      rewriter.replaceOp(binder.op, roiPool.getResult(0));
      return success();
    });
patterns.onOp( patterns.onOp(
"SpaceToDepth", 1, "SpaceToDepth", 1,
[](OpBinder binder, ConversionPatternRewriter &rewriter) { [](OpBinder binder, ConversionPatternRewriter &rewriter) {

View File

@ -6256,6 +6256,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n" " return %0 : !torch.list<int>\n"
" }\n" " }\n"
" func.func @\"__torch_mlir_shape_fn.torchvision.roi_align\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_align\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" return %0 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.torchvision.roi_pool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %3 = torch.prim.TupleConstruct %2, %2 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
" return %3 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.torchvision.roi_pool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
" %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n" " func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list<int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list<int> {\n"
" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n" " %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n"
" %true = torch.constant.bool true\n" " %true = torch.constant.bool true\n"

View File

@ -8,6 +8,7 @@ import argparse
import os import os
import torch import torch
import torchvision
from torch import device from torch import device
import torch.jit._shape_functions as upstream_shape_functions import torch.jit._shape_functions as upstream_shape_functions
@ -85,6 +86,20 @@ def atentriu〡shape(self: List[int], diagonal: int = 0) -> List[int]:
def atentril〡shape(self: List[int], diagonal: int = 0) -> List[int]: def atentril〡shape(self: List[int], diagonal: int = 0) -> List[int]:
return upstream_shape_functions.unary(self) return upstream_shape_functions.unary(self)
def torchvisionroi_align〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> List[int]:
    # Result shape is [rois[0], input[1], pooled_height, pooled_width].
    # NOTE(review): presumably rois[0] is the number of ROIs and input[1] the
    # channel count of the input -- confirm.
    leading_dim = rois[0]
    second_dim = input[1]
    return [leading_dim, second_dim, pooled_height, pooled_width]
def torchvisionroi_align〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> int:
    # The result dtype is the second element of the input's (rank, dtype) pair.
    input_rank, input_dtype = input_rank_dtype
    return input_dtype
def torchvisionroi_pool〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[List[int], List[int]]:
    # Both results share one shape: [rois[0], input[1], pooled_height, pooled_width].
    # NOTE(review): presumably rois[0] is the number of ROIs and input[1] the
    # channel count of the input -- confirm.
    leading_dim = rois[0]
    second_dim = input[1]
    pooled_shape = [leading_dim, second_dim, pooled_height, pooled_width]
    return (pooled_shape, pooled_shape)
def torchvisionroi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]:
    # First result takes the input's dtype; second result is always torch.int64.
    input_rank, input_dtype = input_rank_dtype
    return (input_dtype, torch.int64)
@check_shape_function([ @check_shape_function([
Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4)), # Basic case.
Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`. Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`.

View File

@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
traits=["HasValueSemantics"], traits=["HasValueSemantics"],
) )
# Emit op definitions for the two torchvision ROI ops, identified by their
# registry signatures.
# NOTE(review): presumably these signatures are only present in the registry
# once `torchvision` has been imported before the registry is loaded -- confirm.
emit(
"torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)"
)
emit(
"torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)"
)
def dump_registered_ops(outfile: TextIO, registry: Registry): def dump_registered_ops(outfile: TextIO, registry: Registry):
for _, v in sorted(registry.by_unique_key.items()): for _, v in sorted(registry.by_unique_key.items()):
@ -1173,6 +1180,8 @@ def _maybe_import_op_extensions(args: argparse.Namespace):
def main(args: argparse.Namespace): def main(args: argparse.Namespace):
_maybe_import_op_extensions(args) _maybe_import_op_extensions(args)
import torchvision
registry = Registry.load() registry = Registry.load()
if args.debug_registry_dump: if args.debug_registry_dump:
with open(args.debug_registry_dump, "w") as debug_registry_dump: with open(args.debug_registry_dump, "w") as debug_registry_dump:

View File

@ -2207,6 +2207,37 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve
// ----- // -----
// onnx.RoiAlign with mode = "avg": the batch indices (%arg2) are unsqueezed on
// dim 1, cast with dtype code 6, concatenated in front of the rois (%arg1) and
// the result is passed to torch.torchvision.roi_align together with the input.
// CHECK-LABEL: @test_roialign_avg
func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[Dim:.*]] = torch.constant.int 1
// CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]]
// CHECK: %[[cst6:.*]] = torch.constant.int 6
// CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]]
// CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1
// CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]]
// CHECK: %[[Align:.*]] = torch.torchvision.roi_align %arg0, %[[Cat]]
%0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "output_half_pixel", torch.onnx.mode = "avg", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32>
return %0 : !torch.vtensor<[30,2,5,5],f32>
}
// -----
// onnx.RoiAlign with mode = "max" and the "half_pixel" coordinate mode: the
// batch indices (%arg2) are unsqueezed on dim 1, cast with dtype code 6,
// concatenated in front of the rois (%arg1) and passed to
// torch.torchvision.roi_pool; only the first of its two results is returned.
// CHECK-LABEL: @test_roialign_max
func.func @test_roialign_max(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[Dim:.*]] = torch.constant.int 1
// CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]]
// CHECK: %[[cst6:.*]] = torch.constant.int 6
// CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]]
// CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1
// CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]]
// CHECK: %[[Pool:.*]], %[[Indices:.*]] = torch.torchvision.roi_pool %arg0, %[[Cat]]
// CHECK: return %[[Pool]]
%0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "max", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32>
return %0 : !torch.vtensor<[30,2,5,5],f32>
}
// -----
// CHECK-LABEL: @test_spacetodepth_example // CHECK-LABEL: @test_spacetodepth_example
func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[C0:.*]] = torch.constant.int 0 // CHECK: %[[C0:.*]] = torch.constant.int 0