diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index bab7131f7..4b2ba6def 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -16660,6 +16660,42 @@ def Torch_QuantizedLinearOp : Torch_Op<"quantized.linear", [ }]; } +def Torch_TorchvisionDeformConv2dOp : Torch_Op<"torchvision.deform_conv2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$offset, + AnyTorchTensorType:$mask, + AnyTorchTensorType:$bias, + Torch_IntType:$stride_h, + Torch_IntType:$stride_w, + Torch_IntType:$pad_h, + Torch_IntType:$pad_w, + Torch_IntType:$dilation_h, + Torch_IntType:$dilation_w, + Torch_IntType:$groups, + Torch_IntType:$offset_groups, + Torch_BoolType:$use_mask + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult TorchvisionDeformConv2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 14, 1); + } + void TorchvisionDeformConv2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 14, 1); + } + }]; +} + def Torch_TorchvisionRoiAlignOp : Torch_Op<"torchvision.roi_align", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 6932908c0..c89452ad6 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1837,6 +1837,141 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, transposedInput, reshapeSizesList); return success(); }); + patterns.onOp( + "DeformConv", 19, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + auto loc = binder.getLoc(); + + // get operands + llvm::SmallVector operands; + Torch::ValueTensorType resultType; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType)) + return failure(); + if (operands.size() < 3 || operands.size() > 5) + return failure(); + auto inputType = + dyn_cast(operands[0].getType()); + if (!inputType || !inputType.hasSizes() || + inputType.getSizes().size() != 4) + return rewriter.notifyMatchFailure( + binder.op, "Unsupported: DeformConv with input rank != 4"); + unsigned rank = inputType.getSizes().size(); + auto weightType = + dyn_cast(operands[1].getType()); + if (!weightType || !weightType.hasSizes()) + return failure(); + auto offsetType = + dyn_cast(operands[2].getType()); + if (!offsetType || !offsetType.hasSizes()) + return failure(); + + // get attributes + SmallVector dilations, kernelShape, pads, strides; + SmallVector defaultDilations(rank - 2, 0); + SmallVector defaultPads(2 * (rank - 2), 0); + SmallVector defaultStrides(rank - 2, 1); + int64_t group, offsetGroup; + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations) || + binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {}) || + binder.s64IntegerArrayAttr(pads, "pads", defaultPads) || + binder.s64IntegerArrayAttr(strides, "strides", defaultStrides) || + binder.s64IntegerAttr(group, "group", 1) || + binder.s64IntegerAttr(offsetGroup, "offset_group", 1)) + return failure(); + + for (unsigned i = 0; i < rank - 2; i++) { + if (pads[i] != pads[rank + i - 2]) + return rewriter.notifyMatchFailure( + binder.op, "unsupported: asymmetric padding"); + } + + // Identify and assign names to operands + Value input, weight, offset, bias, mask; + bool useMask = false; + input = operands[0]; + weight = operands[1]; + offset = operands[2]; + if (operands.size() == 4) { + auto unknownOpdRank = Torch::getTensorRank(operands[3]); + if (!unknownOpdRank) + return failure(); + if (*unknownOpdRank == 1) + bias = operands[3]; + else if (*unknownOpdRank == rank) { + mask = operands[3]; + useMask = true; + } else + llvm_unreachable("onnx.DeformConv: optional 4th operand of " + "unexpected rank encountered"); + } + if (operands.size() == 5) { + bias = operands[3]; + mask = operands[4]; + useMask = true; + } + + // assign default operand values if necessary + ArrayRef weightSizes = weightType.getSizes(); + ArrayRef offsetSizes = offsetType.getSizes(); + if (!bias) { + int64_t outputChannels = weightSizes[0]; + SmallVector biasShape(1, outputChannels); + Value biasShapeList = mlir::torch::onnx_c::createConstantIntList( + binder, rewriter, biasShape); + Value cstZero = Torch::getConstantWithGivenDtypeAndValue( + rewriter, loc, 0.0f, inputType.getDtype()); + bias = + Torch::createInitTensor(rewriter, loc, + rewriter.getType( + biasShape, inputType.getDtype()), + cstZero, biasShapeList); + } + if (!mask) { + int64_t batchSize = inputType.getSizes()[0]; + int64_t kernelHeight = weightSizes[2]; + int64_t kernelWidth = weightSizes[3]; + int64_t outputHeight = offsetSizes[2]; + int64_t outputWidth = offsetSizes[3]; + int64_t maskDimOne = offsetGroup * kernelHeight * kernelWidth; + SmallVector maskShape( + {batchSize, maskDimOne, outputHeight, outputWidth}); + Value cstOne = Torch::getConstantWithGivenDtypeAndValue( + rewriter, loc, 1.0f, inputType.getDtype()); + Value maskShapeList = mlir::torch::onnx_c::createConstantIntList( + binder, rewriter, maskShape); + mask = + Torch::createInitTensor(rewriter, loc, + rewriter.getType( + maskShape, inputType.getDtype()), + cstOne, maskShapeList); + } + + // get attributes as constant values + SmallVector dilationValues, padValues, strideValues; + for (auto i : dilations) + dilationValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + for (auto i : pads) + padValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + for (auto i : strides) + strideValues.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + Value groupValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(group)); + Value offsetGroupValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(offsetGroup)); + Value useMaskValue = rewriter.create( + loc, rewriter.getBoolAttr(useMask)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, offset, mask, bias, + strideValues[0], strideValues[1], padValues[0], padValues[1], + dilationValues[0], dilationValues[1], groupValue, offsetGroupValue, + useMaskValue); + return success(); + }); patterns.onOp( "DequantizeLinear", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 69d48fa3c..e94d3bd7c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9492,6 +9492,22 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int9, %int10 : (!torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.torchvision.deform_conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg2, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.torchvision.deform_conv2d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.tuple, %arg4: !torch.tuple, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.bool) -> !torch.int {\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.TupleIndex %arg0, %int1 : !torch.tuple, !torch.int -> !torch.int\n" +" return %0 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index fb997435f..35a34e2b1 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -29,6 +29,9 @@ LINALG_XFAIL_SET = COMMON_TORCH_MLIR_LOWERING_XFAILS | { "InterpolateDynamicModule_scales_recompute_bilinear", "ElementwiseFloatTensorGtIntTensorModule_basic", "AtenIntMM_basic", + # unimplemented lowering torch -> linalg for torchvision.deform_conv2d + # this is added to check the torch.onnx.export -> import_onnx -> torch path + "DeformConv2D_basic", } LINALG_CRASHING_SET = { @@ -383,6 +386,7 @@ FX_IMPORTER_XFAIL_SET = { "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", "ElementwiseAddScalar_NumToTensorFloat_Module_basic", @@ -554,6 +558,7 @@ FX_IMPORTER_STABLEHLO_XFAIL_SET = { "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "DeformConv2D_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -2357,19 +2362,12 @@ ONNX_XFAIL_SET = { "DivIntModule_basic", "ElementwiseAcoshIntModule_basic", "ElementwiseAcoshModule_basic", - "ElementwiseAndScalarModule_basic", - "ElementwiseAndScalarStaticShapeModule_basic", "ElementwiseAsinhIntModule_basic", "ElementwiseAsinhModule_basic", "ElementwiseAtanhIntModule_basic", "ElementwiseAtanhModule_basic", "ElementwiseAtenIsneginfOpModule_basic", "ElementwiseAtenIsposinfOpModule_basic", - "ElementwiseBitwiseAndModule_basic", - "ElementwiseBitwiseAndScalarInt32Module_basic", - "ElementwiseBitwiseAndScalarInt64Module_basic", - "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt32Module_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrModule_basic", @@ -2710,6 +2708,8 @@ ONNX_XFAIL_SET = { "IndexPutHackedTwin3DIntNonAccumulateModule_basic", # RuntimeError: unsupported input type: Device "PrimsIotaModule_basic", + # unimplemented torchvision.deform_conv2d torch->linalg + "DeformConv2D_basic", # Error: 'aten::renorm' to ONNX opset version 17 is not supported. "RenormModuleFloat16_basic", "RenormModuleFloat32NegativeDim_basic", @@ -2759,6 +2759,14 @@ if torch_version_for_comparison() < version.parse("2.4.0.dev"): "ElementwiseBitwiseLeftShiftInt32Module_basic", "ElementwiseBitwiseLeftShiftInt64Module_basic", "ElementwiseBitwiseLeftShiftInt8Module_basic", + # bitwise and support has been added in torch nightly + "ElementwiseAndScalarModule_basic", + "ElementwiseAndScalarStaticShapeModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", } if torch_version_for_comparison() < version.parse("2.4.0.dev"): @@ -2930,6 +2938,7 @@ FX_IMPORTER_TOSA_XFAIL_SET = { "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "DeformConv2D_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -3724,6 +3733,7 @@ ONNX_TOSA_XFAIL_SET = { "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "DeformConv2D_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 97fe12255..1f70a42ce 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -8,7 +8,6 @@ import argparse import os import torch -import torchvision from torch import device import torch.jit._shape_functions as upstream_shape_functions @@ -1639,6 +1638,12 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: assert False, "Unsupported dtype" +def torchvision〇deform_conv2d〡shape(input: List[int], weight: List[int], offset: List[int], mask: List[int], bias: List[int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> List[int]: + return [input[0], weight[0], offset[2], offset[3]] + +def torchvision〇deform_conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], offset_rank_dtype: Tuple[int, int], mask_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], stride_h: int, stride_w: int, pad_h: int, pad_w: int, dilation_h: int, dilation_w: int, groups: int, offset_groups: int, use_mask: bool) -> int: + return input_rank_dtype[1] + def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) @@ -5117,6 +5122,9 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args): _maybe_import_op_extensions(args) + # importing torchvision will register torchvision ops with the JITOperatorRegistry + import torchvision + asm = generate_library(globals()) # We're about to put quotes around the string, so escape the `"` characters. asm = asm.replace("\"", "\\\"") diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 401e7bef2..7c3f79ef4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -1155,6 +1155,13 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry): traits=["HasValueSemantics"], ) + # ========================================================================== + # `torchvision::` namespace. + # ========================================================================== + + emit( + "torchvision::deform_conv2d : (Tensor, Tensor, Tensor, Tensor, Tensor, int, int, int, int, int, int, int, int, bool) -> (Tensor)" + ) emit( "torchvision::roi_align : (Tensor, Tensor, float, int, int, int, bool) -> (Tensor)" ) @@ -1180,6 +1187,7 @@ def _maybe_import_op_extensions(args: argparse.Namespace): def main(args: argparse.Namespace): _maybe_import_op_extensions(args) + # importing torchvision will register torchvision ops with the JITOperatorRegistry import torchvision registry = Registry.load() diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index fb9b2712d..fc0d488b4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -9,6 +9,7 @@ from typing import Any import io import onnx import torch +from torch.onnx._constants import ONNX_TORCHSCRIPT_EXPORTER_MAX_OPSET as max_opset_ver import torch_mlir from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -78,7 +79,12 @@ def convert_onnx(model, inputs): examples = tuple(examples) torch.onnx.export( - model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors + model, + examples, + buffer, + input_names=input_names, + dynamic_axes=dynamic_tensors, + opset_version=max_opset_ver, ) buffer = buffer.getvalue() return import_onnx(buffer) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index af8bea091..2e00e2079 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -1256,3 +1256,90 @@ def ConvTranspose2DQInt8_basic(module, tu: TestUtils): tu.randint(Cin, Cout, Hker, Wker, low=-128, high=127).to(torch.int8), torch.rand(Cout), ) + + +# torchvision.deform_conv2d + +import torchvision + +# This section defines a torch->onnx path for this torchvision op so we can test the onnx paths e2e. + +# Create symbolic function +from torch.onnx.symbolic_helper import parse_args, _get_tensor_sizes + + +@parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b") +def symbolic_deform_conv2d_forward( + g, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask, +): + args = [input, weight, offset, bias] + if use_mask: + args.append(mask) + weight_size = _get_tensor_sizes(weight) + kwargs = { + "dilations_i": [dilation_h, dilation_w], + "group_i": groups, + "kernel_shape_i": weight_size[2:], + "offset_group_i": offset_groups, + # NB: ONNX supports asymmetric padding, whereas PyTorch supports only + # symmetric padding + "pads_i": [pad_h, pad_w, pad_h, pad_w], + "strides_i": [stride_h, stride_w], + } + return g.op("DeformConv", *args, **kwargs) + + +# Register symbolic function +from torch.onnx import register_custom_op_symbolic + +register_custom_op_symbolic( + "torchvision::deform_conv2d", symbolic_deform_conv2d_forward, 19 +) + +N = 1 +Cin = 1 +Hin = 7 +Win = 6 +Cout = 1 +Hker = 2 +Wker = 2 +offset_groups = 1 +Hout = 6 +Wout = 5 +offset_dim1 = 2 * offset_groups * Hker * Wker + + +class DeformableConvModule(torch.nn.Module): + @export + @annotate_args( + [ + None, + ([N, Cin, Hin, Win], torch.float32, True), + ([N, offset_dim1, Hout, Wout], torch.float32, True), + ([Cout, Cin, Hker, Wker], torch.float32, True), + ] + ) + def forward(self, input, offset, weight): + return torchvision.ops.deform_conv2d(input, offset, weight) + + +@register_test_case(module_factory=lambda: DeformableConvModule()) +def DeformConv2D_basic(module, tu: TestUtils): + input = tu.rand(N, Cin, Hin, Win) + offset = tu.rand(N, offset_dim1, Hout, Wout) + weight = tu.rand(Cout, Cin, Hker, Wker) + module.forward(input, offset, weight) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 74793852d..4b03fccee 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -735,6 +735,19 @@ func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4 // ----- +// CHECK-LABEL: @test_deform_conv +func.func @test_deform_conv(%arg0: !torch.vtensor<[1,1,7,6],f32>, %arg1: !torch.vtensor<[1,8,6,5],f32>, %arg2: !torch.vtensor<[1,1,2,2],f32>, %arg3: !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "2.4.0"} { + // CHECK: %[[cstOne:.*]] = torch.constant.float 1.000000e+00 + // CHECK: %[[mask:.*]] = torch.aten.full %[[sizeList:.*]], %[[cstOne]] + // CHECK-SAME: -> !torch.vtensor<[1,4,6,5],f32> + // CHECK: torch.torchvision.deform_conv2d %arg0, %arg2, %arg1, %[[mask]], %arg3 + // CHECK-SAME: : !torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1,4,6,5],f32>, !torch.vtensor<[1],f32>, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[1,1,6,5],f32> + %1 = torch.operator "onnx.DeformConv"(%arg0, %arg2, %arg1, %arg3) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.offset_group = 1 : si64, torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [1 : si64, 1 : si64]} : (!torch.vtensor<[1,1,7,6],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.vtensor<[1,8,6,5],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[1,1,6,5],f32> + return %1 : !torch.vtensor<[1,1,6,5],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_si8 func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32>