diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c351d845c..bab7131f7 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -16660,3 +16660,60 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [
   }];
 }
 
+def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$rois,
+    Torch_FloatType:$spatial_scale,
+    Torch_IntType:$pooled_height,
+    Torch_IntType:$pooled_width,
+    Torch_IntType:$sampling_ratio,
+    Torch_BoolType:$aligned
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult TorchvisionRoiAlignOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 7, 1);
+    }
+    void TorchvisionRoiAlignOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 7, 1);
+    }
+  }];
+}
+
+def Torch_TorchvisionRoiPoolOp : Torch_Op<"torchvision.roi_pool", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$input,
+    AnyTorchTensorType:$rois,
+    Torch_FloatType:$spatial_scale,
+    Torch_IntType:$pooled_height,
+    Torch_IntType:$pooled_width
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result0,
+    AnyTorchOptionalTensorType:$result1
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult TorchvisionRoiPoolOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 5, 2);
+    }
+    void TorchvisionRoiPoolOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 5, 2);
+    }
+  }];
+}
+
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
index a6d05d7cc..58d8397ee 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp
@@ -2953,6 +2953,104 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ(
             /*Torch_BoolType:$antialias*/ cstFalse);
         return success();
       });
+  patterns.onOp(
+      "RoiAlign", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+        // operands = input, rois, batch_indices
+        SmallVector<Value> operands;
+        std::string coordTfMode, mode;
+        int64_t outHInt, outWInt, samplingRatioInt;
+        float spatialScaleFloat;
+        Torch::ValueTensorType resultType;
+        if (binder.tensorOperands(operands, 3) ||
+            binder.customOpNameStringAttr(
+                coordTfMode, "coordinate_transformation_mode", "half_pixel") ||
+            binder.customOpNameStringAttr(mode, "mode", "avg") ||
+            binder.s64IntegerAttr(outHInt, "output_height", 1) ||
+            binder.s64IntegerAttr(outWInt, "output_width", 1) ||
+            binder.s64IntegerAttr(samplingRatioInt, "sampling_ratio", 0) ||
+            binder.f32FloatAttr(spatialScaleFloat, "spatial_scale", 1.0f) ||
+            binder.tensorResultType(resultType))
+          return failure();
+        Value input = operands[0];
+        Value rois = operands[1];
+        Value batchIndices = operands[2];
+
+        // The torchvision roi_pool op does not support these features:
+        if (mode == "max" &&
+            (coordTfMode != "half_pixel" || samplingRatioInt != 0))
+          return rewriter.notifyMatchFailure(
+              binder.op, "unsupported: roi max pooling without default "
+                         "coordTfMode and sampling_ratio");
+
+        Location loc = binder.getLoc();
+        // Concatenate batchIndices with rois to make rois a num_rois x 5
+        // tensor. The batchIndices tensor is int64 and needs to be converted
+        // to the rois dtype before concatenation.
+        auto roisType = dyn_cast<Torch::ValueTensorType>(rois.getType());
+        if (!roisType || !roisType.hasSizes())
+          return failure();
+        Value cstDim = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(1));
+        FailureOr<Value> unsqueezeIndices =
+            Torch::unsqueezeTensor(rewriter, binder.op, batchIndices, cstDim);
+        if (failed(unsqueezeIndices))
+          return failure();
+        batchIndices = unsqueezeIndices.value();
+        auto batchIndicesType =
+            cast<Torch::ValueTensorType>(batchIndices.getType());
+        Value dTypeInt =
+            Torch::getDtypeIntValueForType(rewriter, loc, roisType.getDtype());
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+        Value cstFalse =
+            rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
+        Value newBatchIndices = rewriter.create<Torch::AtenToDtypeOp>(
+            loc,
+            batchIndicesType.getWithSizesAndDtype(
+                batchIndicesType.getOptionalSizes(),
+                roisType.getOptionalDtype()),
+            batchIndices, dTypeInt, cstFalse, cstFalse, none);
+        SmallVector<int64_t> roiSizes(roisType.getSizes());
+        roiSizes.back() = 5;
+        auto catType = rewriter.getType<Torch::ValueTensorType>(
+            roiSizes, roisType.getDtype());
+        Type listElemType =
+            roisType.getWithSizesAndDtype(/*optionalSizes=*/std::nullopt,
+                                          /*optionalDtype=*/nullptr);
+        Type listType = Torch::ListType::get(listElemType);
+        Value tensorList = rewriter.create<Torch::PrimListConstructOp>(
+            binder.op->getLoc(), listType, ValueRange{newBatchIndices, rois});
+        Value newRois = rewriter.create<Torch::AtenCatOp>(loc, catType,
+                                                          tensorList, cstDim);
+
+        // Materialize constants from the attributes.
+        Value cstSpatialScale = rewriter.create<Torch::ConstantFloatOp>(
+            loc, rewriter.getF64FloatAttr(spatialScaleFloat));
+        Value pooledHeight = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(outHInt));
+        Value pooledWidth = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(outWInt));
+        // This matches the default PyTorch sampling_ratio value.
+        samplingRatioInt = (samplingRatioInt == 0) ? -1 : samplingRatioInt;
+        Value samplingRatio = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(samplingRatioInt));
+        bool aligned = coordTfMode == "half_pixel";
+        Value cstAligned =
+            rewriter.create<Torch::ConstantBoolOp>(loc, aligned);
+
+        if (mode == "avg") {
+          rewriter.replaceOpWithNewOp<Torch::TorchvisionRoiAlignOp>(
+              binder.op, resultType, input, newRois, cstSpatialScale,
+              pooledHeight, pooledWidth, samplingRatio, cstAligned);
+          return success();
+        }
+        // mode == "max"
+        auto indicesType = resultType.getWithSizesAndDtype(
+            resultType.getOptionalSizes(), batchIndicesType.getDtype());
+        auto roiPool = rewriter.create<Torch::TorchvisionRoiPoolOp>(
+            loc, TypeRange{resultType, indicesType}, input, newRois,
+            cstSpatialScale, pooledHeight, pooledWidth);
+        rewriter.replaceOp(binder.op, roiPool.getResult(0));
+        return success();
+      });
   patterns.onOp(
       "SpaceToDepth", 1,
       [](OpBinder binder, ConversionPatternRewriter &rewriter) {
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 537d3b619..69d48fa3c 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6256,6 +6256,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.torchvision.roi_align\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.list<int> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    return %2 : !torch.list<int>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.torchvision.roi_align\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.bool) -> !torch.int {\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"    return %0 : !torch.int\n"
+"  }\n"
+"  func.func @\"__torch_mlir_shape_fn.torchvision.roi_pool\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<list<int>, list<int>> {\n"
+"    %int0 = torch.constant.int 0\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list<int>, !torch.int -> !torch.int\n"
+"    %2 = torch.prim.ListConstruct %0, %1, %arg3, %arg4 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
+"    %3 = torch.prim.TupleConstruct %2, %2 : !torch.list<int>, !torch.list<int> -> !torch.tuple<list<int>, list<int>>\n"
+"    return %3 : !torch.tuple<list<int>, list<int>>\n"
+"  }\n"
+"  func.func @\"__torch_mlir_dtype_fn.torchvision.roi_pool\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple<int, int> {\n"
+"    %int4 = torch.constant.int 4\n"
+"    %int1 = torch.constant.int 1\n"
+"    %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple<int, int>, !torch.int -> !torch.int\n"
+"    %1 = torch.prim.TupleConstruct %0, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
+"    return %1 : !torch.tuple<int, int>\n"
+"  }\n"
@\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n" " %true = torch.constant.bool true\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index e77a1978b..97fe12255 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -8,6 +8,7 @@ import argparse import os import torch +import torchvision from torch import device import torch.jit._shape_functions as upstream_shape_functions @@ -85,6 +86,20 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) + +def torchvision〇roi_align〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> List[int]: + return [rois[0], input[1], pooled_height, pooled_width] + +def torchvision〇roi_align〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int, sampling_ratio: int, aligned: bool) -> int: + return input_rank_dtype[1] + +def torchvision〇roi_pool〡shape(input: List[int], rois: List[int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[List[int], List[int]]: + output = [rois[0], input[1], pooled_height, pooled_width] + return (output, output) + +def torchvision〇roi_pool〡dtype(input_rank_dtype: Tuple[int, int], rois_rank_dtype: Tuple[int, int], spatial_scale: float, pooled_height: int, pooled_width: int) -> Tuple[int, int]: + return (input_rank_dtype[1], torch.int64) + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`. 
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index b21362f7c..401e7bef2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): traits=["HasValueSemantics"], ) + emit( + "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)" + ) + emit( + "torchvision::roi_pool : (Tensor, Tensor, float, int, int) -> (Tensor, Tensor)" + ) + def dump_registered_ops(outfile: TextIO, registry: Registry): for _, v in sorted(registry.by_unique_key.items()): @@ -1173,6 +1180,8 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args: argparse.Namespace): _maybe_import_op_extensions(args) + import torchvision + registry = Registry.load() if args.debug_registry_dump: with open(args.debug_registry_dump, "w") as debug_registry_dump: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 445d54c86..d611823f9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -2207,6 +2207,37 @@ f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_ve // ----- +// CHECK-LABEL: @test_roialign_avg + func.func @test_roialign_avg(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[Dim:.*]] = torch.constant.int 1 + // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]] + // CHECK: %[[cst6:.*]] = torch.constant.int 6 + // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]] + // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1 + // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]] + // CHECK: %[[Align:.*]] = torch.torchvision.roi_align %arg0, %[[Cat]] + %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "output_half_pixel", torch.onnx.mode = "avg", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> + return %0 : !torch.vtensor<[30,2,5,5],f32> + } + +// ----- + +// CHECK-LABEL: @test_roialign_max + func.func @test_roialign_max(%arg0: !torch.vtensor<[6,2,100,100],f32>, %arg1: !torch.vtensor<[30,4],f32>, %arg2: !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[Dim:.*]] = torch.constant.int 1 + // CHECK: %[[Unsqueeze:.*]] = torch.aten.unsqueeze %arg2, %[[Dim]] + // CHECK: %[[cst6:.*]] = torch.constant.int 6 + // CHECK: %[[Cast:.*]] = torch.aten.to.dtype %[[Unsqueeze]], %[[cst6]] + // CHECK: %[[List:.*]] = torch.prim.ListConstruct %[[Cast]], %arg1 + // CHECK: %[[Cat:.*]] = torch.aten.cat %[[List]], %[[Dim]] + 
// CHECK: %[[Pool:.*]], %[[Indices:.*]] = torch.torchvision.roi_pool %arg0, %[[Cat]] + // CHECK: return %[[Pool]] + %0 = torch.operator "onnx.RoiAlign"(%arg0, %arg1, %arg2) {torch.onnx.coordinate_transformation_mode = "half_pixel", torch.onnx.mode = "max", torch.onnx.output_height = 5 : si64, torch.onnx.output_width = 5 : si64, torch.onnx.sampling_ratio = 0 : si64, torch.onnx.spatial_scale = 1.000000e+00 : f32} : (!torch.vtensor<[6,2,100,100],f32>, !torch.vtensor<[30,4],f32>, !torch.vtensor<[30],si64>) -> !torch.vtensor<[30,2,5,5],f32> + return %0 : !torch.vtensor<[30,2,5,5],f32> + } + +// ----- + // CHECK-LABEL: @test_spacetodepth_example func.func @test_spacetodepth_example(%arg0: !torch.vtensor<[1,1,4,6],f32>) -> !torch.vtensor<[1,4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0
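
Aside (not part of the patch): the `unsqueeze`/`to.dtype`/`cat` sequence that the CHECK lines above verify corresponds to the following eager-mode computation. A hedged sketch; the tensors and attribute values here are illustrative assumptions.

```python
# Sketch: mirrors the rois preparation and attribute mapping emitted by the
# ONNX RoiAlign lowering. Values are illustrative.
import torch

batch_indices = torch.tensor([0, 0, 1], dtype=torch.int64)  # [num_rois]
rois = torch.rand(3, 4)                                     # [num_rois, 4]
# Unsqueeze to [num_rois, 1], cast to the rois dtype, then cat -> [num_rois, 5].
new_rois = torch.cat([batch_indices.unsqueeze(1).to(rois.dtype), rois], dim=1)
assert new_rois.shape == (3, 5)

# Attribute mapping from the lowering: ONNX sampling_ratio == 0 becomes the
# PyTorch default -1, and only coordinate_transformation_mode == "half_pixel"
# maps to aligned=True.
sampling_ratio = 0
sampling_ratio = -1 if sampling_ratio == 0 else sampling_ratio
coord_tf_mode = "output_half_pixel"
aligned = coord_tf_mode == "half_pixel"  # False here
```
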