diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ec14fd287..684d33c92 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -2556,7 +2556,7 @@ def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
   let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchOptionalTensorListType:$indices,
+    AnyTorchListOfOptionalTensorType:$indices,
     AnyTorchTensorType:$values,
     Torch_BoolType:$accumulate
   );
@@ -2581,7 +2581,7 @@ def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
   let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchOptionalTensorListType:$indices,
+    AnyTorchListOfOptionalTensorType:$indices,
     AnyTorchTensorType:$values,
     Torch_BoolType:$accumulate
   );
@@ -2709,9 +2709,9 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
     AnyTorchTensorType:$input,
     AnyTorchTensorType:$weight,
     AnyTorchOptionalTensorType:$bias,
-    TorchIntListType:$stride,
-    TorchIntListType:$padding,
-    TorchIntListType:$dilation,
+    ListOfTorchIntType:$stride,
+    ListOfTorchIntType:$padding,
+    ListOfTorchIntType:$dilation,
     Torch_IntType:$groups
   );
   let results = (outs
@@ -2799,7 +2799,7 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
   let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$input,
-    TorchIntListType:$normalized_shape,
+    ListOfTorchIntType:$normalized_shape,
     AnyTorchOptionalTensorType:$weight,
     AnyTorchOptionalTensorType:$bias,
     Torch_FloatType:$eps,
@@ -2827,7 +2827,7 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
   let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$input,
-    TorchIntListType:$normalized_shape,
+    ListOfTorchIntType:$normalized_shape,
     AnyTorchOptionalTensorType:$weight,
     AnyTorchOptionalTensorType:$bias,
     Torch_FloatType:$eps
@@ -2856,10 +2856,10 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
   let summary = "Generated op for `aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$kernel_size,
-    TorchIntListType:$stride,
-    TorchIntListType:$padding,
-    TorchIntListType:$dilation,
+    ListOfTorchIntType:$kernel_size,
+    ListOfTorchIntType:$stride,
+    ListOfTorchIntType:$padding,
+    ListOfTorchIntType:$dilation,
     Torch_BoolType:$ceil_mode
   );
   let results = (outs
@@ -2959,7 +2959,7 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
   let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$output_size
+    ListOfTorchIntType:$output_size
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -3034,7 +3034,7 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
   let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$dims
+    ListOfTorchIntType:$dims
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -3131,7 +3131,7 @@ def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [
   let summary = "Generated op for `aten::logsumexp : (Tensor, int[], bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$dim,
+    ListOfTorchIntType:$dim,
     Torch_BoolType:$keepdim
   );
   let results = (outs
@@ -3156,7 +3156,7 @@ def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [
   let summary = "Generated op for `aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$dim,
+    ListOfTorchIntType:$dim,
     Torch_BoolType:$keepdim,
     TorchOptionalIntType:$dtype
   );
@@ -3408,7 +3408,7 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
   let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$pad,
+    ListOfTorchIntType:$pad,
     AnyTorchScalarType:$value
   );
   let results = (outs
@@ -3553,7 +3553,7 @@ def Torch_AtenSizeOp : Torch_Op<"aten.size", [
     AnyTorchTensorType:$self
   );
   let results = (outs
-    TorchIntListType:$result
+    ListOfTorchIntType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
@@ -3597,7 +3597,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
   ]> {
   let summary = "Generated op for `aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)`";
   let arguments = (ins
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
     TorchOptionalDeviceType:$device,
@@ -3625,7 +3625,7 @@ def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
   let summary = "Generated op for `aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
     TorchOptionalDeviceType:$device,
@@ -3652,7 +3652,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
   ]> {
   let summary = "Generated op for `aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)`";
   let arguments = (ins
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
     TorchOptionalDeviceType:$device,
@@ -3701,7 +3701,7 @@ def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
   let summary = "Generated op for `aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
     TorchOptionalDeviceType:$device,
@@ -4212,7 +4212,7 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
   let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
     TorchOptionalDeviceType:$device,
@@ -4295,7 +4295,7 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
   ]> {
   let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
   let arguments = (ins
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
     TorchOptionalDeviceType:$device,
@@ -4323,7 +4323,7 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
   let summary = "Generated op for `aten::expand : (Tensor, int[], bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     Torch_BoolType:$implicit
   );
   let results = (outs
@@ -4370,7 +4370,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
   let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size
+    ListOfTorchIntType:$size
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -4394,7 +4394,7 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
   let summary = "Generated op for `aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchOptionalTensorListType:$indices
+    AnyTorchListOfOptionalTensorType:$indices
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -4441,7 +4441,7 @@ def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [
   let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchOptionalTensorListType:$indices,
+    AnyTorchListOfOptionalTensorType:$indices,
     AnyTorchTensorType:$values,
     Torch_BoolType:$accumulate,
     Torch_BoolType:$unsafe
@@ -4538,7 +4538,7 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
   let summary = "Generated op for `aten::repeat : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$repeats
+    ListOfTorchIntType:$repeats
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -4561,7 +4561,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
   let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$shape
+    ListOfTorchIntType:$shape
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -4584,8 +4584,8 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
   let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size,
-    TorchIntListType:$stride
+    ListOfTorchIntType:$size,
+    ListOfTorchIntType:$stride
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -4607,7 +4607,7 @@ def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
   let summary = "Generated op for `aten::resize_ : (Tensor, int[], int?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     TorchOptionalIntType:$memory_format
   );
   let results = (outs
@@ -4680,7 +4680,7 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
   ]> {
   let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorListType:$tensors,
+    AnyTorchListOfTensorType:$tensors,
     Torch_IntType:$dim
   );
   let results = (outs
@@ -4729,7 +4729,7 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
   let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$dim,
+    ListOfTorchIntType:$dim,
     Torch_BoolType:$keepdim,
     TorchOptionalIntType:$dtype
   );
@@ -4906,7 +4906,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
   let summary = "Generated op for `aten::view : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size
+    ListOfTorchIntType:$size
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -4931,7 +4931,7 @@ def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
   let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    TorchIntListType:$size
+    ListOfTorchIntType:$size
   );
   let results = (outs
     AnyTorchTensorType:$result
@@ -5295,7 +5295,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
   ]> {
   let summary = "Generated op for `aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
   let arguments = (ins
-    TorchIntListType:$size,
+    ListOfTorchIntType:$size,
     AnyTorchScalarType:$fill_value,
     TorchOptionalIntType:$dtype,
     TorchOptionalIntType:$layout,
@@ -5425,7 +5425,7 @@ def Torch_AtenKeysStrOp : Torch_Op<"aten.keys.str", [
     Torch_DictType:$self
   );
   let results = (outs
-    TorchStringListType:$result
+    ListOfTorchStringType:$result
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
@@ -5490,7 +5490,7 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
   ]> {
   let summary = "Generated op for `aten::cat : (Tensor[], int) -> (Tensor)`";
   let arguments = (ins
-    AnyTorchTensorListType:$tensors,
+    AnyTorchListOfTensorType:$tensors,
     Torch_IntType:$dim
   );
   let results = (outs
@@ -5560,8 +5560,8 @@ def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [
   ]> {
   let summary = "Generated op for `aten::eq.int_list : (int[], int[]) -> (bool)`";
   let arguments = (ins
-    TorchIntListType:$a,
-    TorchIntListType:$b
+    ListOfTorchIntType:$a,
+    ListOfTorchIntType:$b
   );
   let results = (outs
     Torch_BoolType:$result
@@ -5656,8 +5656,8 @@ def Torch_AtenNeIntListOp : Torch_Op<"aten.ne.int_list", [
   ]> {
   let summary = "Generated op for `aten::ne.int_list : (int[], int[]) -> (bool)`";
   let arguments = (ins
-    TorchIntListType:$a,
-    TorchIntListType:$b
+    ListOfTorchIntType:$a,
+    ListOfTorchIntType:$b
   );
   let results = (outs
     Torch_BoolType:$result
@@ -5767,7 +5767,7 @@ def Torch_AtenJoinOp : Torch_Op<"aten.join", [
   let summary = "Generated op for `aten::join : (str, str[]) -> (str)`";
   let arguments = (ins
     Torch_StringType:$self,
-    TorchStringListType:$values
+    ListOfTorchStringType:$values
   );
   let results = (outs
     Torch_StringType:$result
@@ -6905,7 +6905,7 @@ def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
   ]> {
   let summary = "Generated op for `prim::min.self_int : (int[]) -> (int)`";
   let arguments = (ins
-    TorchIntListType:$self
+    ListOfTorchIntType:$self
   );
   let results = (outs
     Torch_IntType:$result
@@ -6953,7 +6953,7 @@ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
   ]> {
   let summary = "Generated op for `prim::max.self_int : (int[]) -> (int)`";
   let arguments = (ins
-    TorchIntListType:$self
+    ListOfTorchIntType:$self
   );
   let results = (outs
     Torch_IntType:$result
diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
index 63061074b..f57a6b006 100644
--- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td
@@ -1042,7 +1042,7 @@ def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl"
"`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalTensorListType:$indices, + AnyTorchListOfOptionalTensorType:$indices, AnyTorchTensorType:$values, Torch_BoolType:$accumulate, Torch_BoolType:$unsafe @@ -1171,7 +1171,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes", }]; let arguments = (ins - Variadic:$results + Variadic:$results ); let results = (outs); diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 2eb6c5a43..c643eb646 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -425,14 +425,15 @@ class ListOf allowedTypes, string descr> : "$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()", descr, "::mlir::torch::Torch::ListType">; -def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; -def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">; -def TorchStringListType : ListOf<[Torch_StringType], "Str list type (str[])">; -def AnyTorchTensorListType: +def ListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; +def ListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">; +def ListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">; +def AnyTorchListOfTensorType: ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">; -def AnyTorchOptionalTensorListType : +def AnyTorchListOfOptionalTensorType : ListOf<[AnyTorchOptionalTensorType], "Any optional tensor list type (Tensor?[])">; +def OptionalListOfTorchIntType : OptionalOf; // Note: TorchScript does not consider !torch.bool to be a Scalar. def AnyTorchScalarType : diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index a0be58bc4..680341ea6 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -20,15 +20,16 @@ from .registry import Registry, JitOperator TORCH_TYPE_TO_ODS_TYPE = { "Tensor": "AnyTorchTensorType", "Tensor?": "AnyTorchOptionalTensorType", - "Tensor?[]": "AnyTorchOptionalTensorListType", - "Tensor[]": "AnyTorchTensorListType", + "Tensor?[]": "AnyTorchListOfOptionalTensorType", + "Tensor[]": "AnyTorchListOfTensorType", "Scalar": "AnyTorchScalarType", "Scalar?": "AnyTorchOptionalScalarType", "int": "Torch_IntType", - "int[]": "TorchIntListType", + "int[]": "ListOfTorchIntType", "int?": "TorchOptionalIntType", + "int[]?": "OptionalListOfTorchIntType", "bool": "Torch_BoolType", - "bool[]": "TorchBoolListType", + "bool[]": "ListOfTorchBoolType", "bool?": "TorchOptionalBoolType", "float": "Torch_FloatType", "t[]": "AnyTorchListType", @@ -42,7 +43,7 @@ TORCH_TYPE_TO_ODS_TYPE = { "Generator?": "TorchOptionalGeneratorType", "str": "Torch_StringType", "str?": "TorchOptionalStringType", - "str[]": "TorchStringListType", + "str[]": "ListOfTorchStringType", "Dict": "Torch_DictType", "__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType", }