diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index e1bdfe4d0..77fcaa9b2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2387,7 +2387,7 @@ def Torch_AtenUniform_Op : Torch_Op<"aten.uniform_", [ AnyTorchTensorType:$self, Torch_FloatType:$from, Torch_FloatType:$to, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -2411,11 +2411,11 @@ def Torch_AtenRandLikeOp : Torch_Op<"aten.rand_like", [ let summary = "Generated op for `aten::rand_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -2439,7 +2439,7 @@ def Torch_AtenBernoulliOp : Torch_Op<"aten.bernoulli", [ let summary = "Generated op for `aten::bernoulli : (Tensor, Generator?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -2462,7 +2462,7 @@ def Torch_AtenBernoulli_FloatOp : Torch_Op<"aten.bernoulli_.float", [ let arguments = (ins AnyTorchTensorType:$self, Torch_FloatType:$p, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -2485,7 +2485,7 @@ def Torch_AtenBernoulli_TensorOp : Torch_Op<"aten.bernoulli_.Tensor", [ let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$p, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -2709,9 +2709,9 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ AnyTorchTensorType:$input, AnyTorchTensorType:$weight, AnyTorchOptionalTensorType:$bias, - ListOfTorchIntType:$stride, - ListOfTorchIntType:$padding, - ListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_IntType:$groups ); let results = (outs @@ -2738,11 +2738,11 @@ def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [ AnyTorchTensorType:$input, AnyTorchTensorType:$weight, AnyTorchOptionalTensorType:$bias, - ListOfTorchIntType:$stride, - ListOfTorchIntType:$padding, - ListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_BoolType:$transposed, - ListOfTorchIntType:$output_padding, + AnyTorchListOfTorchIntType:$output_padding, Torch_IntType:$groups ); let results = (outs @@ -2769,11 +2769,11 @@ def Torch_AtenConvolutionOverrideableOp : Torch_Op<"aten.convolution_overrideabl AnyTorchTensorType:$input, AnyTorchTensorType:$weight, AnyTorchOptionalTensorType:$bias, - ListOfTorchIntType:$stride, - ListOfTorchIntType:$padding, - ListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_BoolType:$transposed, - ListOfTorchIntType:$output_padding, + AnyTorchListOfTorchIntType:$output_padding, Torch_IntType:$groups ); let results = (outs @@ -2861,7 +2861,7 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$input, - ListOfTorchIntType:$normalized_shape, + AnyTorchListOfTorchIntType:$normalized_shape, AnyTorchOptionalTensorType:$weight, AnyTorchOptionalTensorType:$bias, Torch_FloatType:$eps, @@ -2889,7 +2889,7 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [ let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$input, - ListOfTorchIntType:$normalized_shape, + AnyTorchListOfTorchIntType:$normalized_shape, AnyTorchOptionalTensorType:$weight, AnyTorchOptionalTensorType:$bias, Torch_FloatType:$eps @@ -2918,10 +2918,10 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [ let summary = "Generated op for `aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$kernel_size, - ListOfTorchIntType:$stride, - ListOfTorchIntType:$padding, - ListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_BoolType:$ceil_mode ); let results = (outs @@ -2946,10 +2946,10 @@ def Torch_AtenMaxPool2dWithIndicesOp : Torch_Op<"aten.max_pool2d_with_indices", let summary = "Generated op for `aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$kernel_size, - ListOfTorchIntType:$stride, - ListOfTorchIntType:$padding, - ListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_BoolType:$ceil_mode ); let results = (outs @@ -2976,10 +2976,10 @@ def Torch_AtenMaxPool2dWithIndicesBackwardOp : Torch_Op<"aten.max_pool2d_with_in let arguments = (ins AnyTorchTensorType:$grad_output, AnyTorchTensorType:$self, - ListOfTorchIntType:$kernel_size, - ListOfTorchIntType:$stride, - ListOfTorchIntType:$padding, - ListOfTorchIntType:$dilation, + AnyTorchListOfTorchIntType:$kernel_size, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, Torch_BoolType:$ceil_mode, AnyTorchTensorType:$indices ); @@ -3006,7 +3006,7 @@ def Torch_AtenSoftmaxIntOp : Torch_Op<"aten.softmax.int", [ let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -3031,7 +3031,7 @@ def Torch_AtenLogSoftmaxIntOp : Torch_Op<"aten.log_softmax.int", [ let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -3080,7 +3080,7 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [ let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$output_size + AnyTorchListOfTorchIntType:$output_size ); let results = (outs AnyTorchTensorType:$result @@ -3155,7 +3155,7 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [ let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$dims + AnyTorchListOfTorchIntType:$dims ); let results = (outs AnyTorchTensorType:$result @@ -3204,7 +3204,7 @@ def Torch_AtenCumsumOp : Torch_Op<"aten.cumsum", [ let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -3252,7 +3252,7 @@ def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [ let summary = "Generated op for `aten::logsumexp : (Tensor, int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$dim, + AnyTorchListOfTorchIntType:$dim, Torch_BoolType:$keepdim ); let results = (outs @@ -3277,9 +3277,9 @@ def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [ let summary = "Generated op for `aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$dim, + AnyTorchListOfTorchIntType:$dim, Torch_BoolType:$keepdim, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -3375,7 +3375,7 @@ def Torch_AtenMeanOp : Torch_Op<"aten.mean", [ let summary = "Generated op for `aten::mean : (Tensor, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -3529,7 +3529,7 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$pad, + AnyTorchListOfTorchIntType:$pad, AnyTorchScalarType:$value ); let results = (outs @@ -3674,7 +3674,7 @@ def Torch_AtenSizeOp : Torch_Op<"aten.size", [ AnyTorchTensorType:$self ); let results = (outs - ListOfTorchIntType:$result + AnyTorchListOfTorchIntType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ @@ -3718,11 +3718,11 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [ ]> { let summary = "Generated op for `aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - ListOfTorchIntType:$size, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -3746,11 +3746,11 @@ def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [ let summary = "Generated op for `aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -3773,11 +3773,11 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [ ]> { let summary = "Generated op for `aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - ListOfTorchIntType:$size, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -3822,11 +3822,11 @@ def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ let summary = "Generated op for `aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -3850,8 +3850,8 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ let summary = "Generated op for `aten::tensor : (t[], int?, Device?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchListType:$data, - TorchOptionalIntType:$dtype, - TorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, Torch_BoolType:$requires_grad ); let results = (outs @@ -3876,8 +3876,8 @@ def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [ let summary = "Generated op for `aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)`"; let arguments = (ins Torch_BoolType:$t, - TorchOptionalIntType:$dtype, - TorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, Torch_BoolType:$requires_grad ); let results = (outs @@ -3902,8 +3902,8 @@ def Torch_AtenTensorIntOp : Torch_Op<"aten.tensor.int", [ let summary = "Generated op for `aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)`"; let arguments = (ins Torch_IntType:$t, - TorchOptionalIntType:$dtype, - TorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, Torch_BoolType:$requires_grad ); let results = (outs @@ -4022,10 +4022,10 @@ def Torch_AtenArangeOp : Torch_Op<"aten.arange", [ let summary = "Generated op for `aten::arange : (Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins AnyTorchScalarType:$end, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -4050,10 +4050,10 @@ def Torch_AtenArangeStartOp : Torch_Op<"aten.arange.start", [ let arguments = (ins AnyTorchScalarType:$start, AnyTorchScalarType:$end, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -4079,10 +4079,10 @@ def Torch_AtenArangeStartStepOp : Torch_Op<"aten.arange.start_step", [ AnyTorchScalarType:$start, AnyTorchScalarType:$end, AnyTorchScalarType:$step, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -4106,7 +4106,7 @@ def Torch_AtenArgmaxOp : Torch_Op<"aten.argmax", [ let summary = "Generated op for `aten::argmax : (Tensor, int?, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dim, + AnyTorchOptionalIntType:$dim, Torch_BoolType:$keepdim ); let results = (outs @@ -4157,7 +4157,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [ let summary = "Generated op for `aten::clone : (Tensor, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4227,12 +4227,12 @@ def Torch_Aten_ToCopyOp : Torch_Op<"aten._to_copy", [ let summary = "Generated op for `aten::_to_copy : (Tensor, int?, int?, Device?, bool?, bool, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, Torch_BoolType:$non_blocking, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4305,11 +4305,11 @@ def Torch_AtenEmptyLikeOp : Torch_Op<"aten.empty_like", [ let summary = "Generated op for `aten::empty_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4333,11 +4333,11 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [ let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -4361,11 +4361,11 @@ def Torch_AtenZerosLikeOp : Torch_Op<"aten.zeros_like", [ let summary = "Generated op for `aten::zeros_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4389,11 +4389,11 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [ let summary = "Generated op for `aten::ones_like : (Tensor, int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4416,12 +4416,12 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [ ]> { let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`"; let arguments = (ins - ListOfTorchIntType:$size, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, - TorchOptionalIntType:$memory_format + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4444,7 +4444,7 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [ let summary = "Generated op for `aten::expand : (Tensor, int[], bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$size, Torch_BoolType:$implicit ); let results = (outs @@ -4491,7 +4491,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [ let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size + AnyTorchListOfTorchIntType:$size ); let results = (outs AnyTorchTensorType:$result @@ -4659,7 +4659,7 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [ let summary = "Generated op for `aten::repeat : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$repeats + AnyTorchListOfTorchIntType:$repeats ); let results = (outs AnyTorchTensorType:$result @@ -4682,7 +4682,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [ let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$shape + AnyTorchListOfTorchIntType:$shape ); let results = (outs AnyTorchTensorType:$result @@ -4705,8 +4705,8 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [ let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size, - ListOfTorchIntType:$stride + AnyTorchListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$stride ); let results = (outs AnyTorchTensorType:$result @@ -4728,8 +4728,8 @@ def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [ let summary = "Generated op for `aten::resize_ : (Tensor, int[], int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size, - TorchOptionalIntType:$memory_format + AnyTorchListOfTorchIntType:$size, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4826,7 +4826,7 @@ def Torch_AtenSumOp : Torch_Op<"aten.sum", [ let summary = "Generated op for `aten::sum : (Tensor, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -4850,9 +4850,9 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [ let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$dim, + AnyTorchListOfTorchIntType:$dim, Torch_BoolType:$keepdim, - TorchOptionalIntType:$dtype + AnyTorchOptionalIntType:$dtype ); let results = (outs AnyTorchTensorType:$result @@ -4927,7 +4927,7 @@ def Torch_AtenToDtypeOp : Torch_Op<"aten.to.dtype", [ Torch_IntType:$dtype, Torch_BoolType:$non_blocking, Torch_BoolType:$copy, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4954,7 +4954,7 @@ def Torch_AtenToOtherOp : Torch_Op<"aten.to.other", [ AnyTorchTensorType:$other, Torch_BoolType:$non_blocking, Torch_BoolType:$copy, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -4977,8 +4977,8 @@ def Torch_AtenToPrimDeviceOp : Torch_Op<"aten.to.prim_Device", [ let summary = "Generated op for `aten::to.prim_Device : (Tensor, Device?, int?, bool, bool) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - TorchOptionalDeviceType:$device, - TorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$dtype, Torch_BoolType:$non_blocking, Torch_BoolType:$copy ); @@ -5027,7 +5027,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [ let summary = "Generated op for `aten::view : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size + AnyTorchListOfTorchIntType:$size ); let results = (outs AnyTorchTensorType:$result @@ -5052,7 +5052,7 @@ def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [ let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - ListOfTorchIntType:$size + AnyTorchListOfTorchIntType:$size ); let results = (outs AnyTorchTensorType:$result @@ -5176,8 +5176,8 @@ def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ let arguments = (ins AnyTorchTensorType:$self, Torch_IntType:$dim, - TorchOptionalIntType:$start, - TorchOptionalIntType:$end, + AnyTorchOptionalIntType:$start, + AnyTorchOptionalIntType:$end, Torch_IntType:$step ); let results = (outs @@ -5296,8 +5296,8 @@ def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ let summary = "Generated op for `aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)`"; let arguments = (ins Torch_FloatType:$t, - TorchOptionalIntType:$dtype, - TorchOptionalDeviceType:$device, + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalDeviceType:$device, Torch_BoolType:$requires_grad ); let results = (outs @@ -5416,12 +5416,12 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [ ]> { let summary = "Generated op for `aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`"; let arguments = (ins - ListOfTorchIntType:$size, + AnyTorchListOfTorchIntType:$size, AnyTorchScalarType:$fill_value, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory ); let results = (outs AnyTorchTensorType:$result @@ -5446,11 +5446,11 @@ def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [ let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$fill_value, - TorchOptionalIntType:$dtype, - TorchOptionalIntType:$layout, - TorchOptionalDeviceType:$device, - TorchOptionalBoolType:$pin_memory, - TorchOptionalIntType:$memory_format + AnyTorchOptionalIntType:$dtype, + AnyTorchOptionalIntType:$layout, + AnyTorchOptionalDeviceType:$device, + AnyTorchOptionalBoolType:$pin_memory, + AnyTorchOptionalIntType:$memory_format ); let results = (outs AnyTorchTensorType:$result @@ -5546,7 +5546,7 @@ def Torch_AtenKeysStrOp : Torch_Op<"aten.keys.str", [ Torch_DictType:$self ); let results = (outs - ListOfTorchStringType:$result + AnyTorchListOfTorchStringType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ @@ -5681,8 +5681,8 @@ def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [ ]> { let summary = "Generated op for `aten::eq.int_list : (int[], int[]) -> (bool)`"; let arguments = (ins - ListOfTorchIntType:$a, - ListOfTorchIntType:$b + AnyTorchListOfTorchIntType:$a, + AnyTorchListOfTorchIntType:$b ); let results = (outs Torch_BoolType:$result @@ -5730,8 +5730,8 @@ def Torch_AtenSliceTOp : Torch_Op<"aten.slice.t", [ let summary = "Generated op for `aten::slice.t : (t[], int?, int?, int) -> (t[])`"; let arguments = (ins AnyTorchListType:$l, - TorchOptionalIntType:$start, - TorchOptionalIntType:$end, + AnyTorchOptionalIntType:$start, + AnyTorchOptionalIntType:$end, Torch_IntType:$step ); let results = (outs @@ -5777,8 +5777,8 @@ def Torch_AtenNeIntListOp : Torch_Op<"aten.ne.int_list", [ ]> { let summary = "Generated op for `aten::ne.int_list : (int[], int[]) -> (bool)`"; let arguments = (ins - ListOfTorchIntType:$a, - ListOfTorchIntType:$b + AnyTorchListOfTorchIntType:$a, + AnyTorchListOfTorchIntType:$b ); let results = (outs Torch_BoolType:$result @@ -5888,7 +5888,7 @@ def Torch_AtenJoinOp : Torch_Op<"aten.join", [ let summary = "Generated op for `aten::join : (str, str[]) -> (str)`"; let arguments = (ins Torch_StringType:$self, - ListOfTorchStringType:$values + AnyTorchListOfTorchStringType:$values ); let results = (outs Torch_StringType:$result @@ -7124,7 +7124,7 @@ def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [ ]> { let summary = "Generated op for `prim::min.self_int : (int[]) -> (int)`"; let arguments = (ins - ListOfTorchIntType:$self + AnyTorchListOfTorchIntType:$self ); let results = (outs Torch_IntType:$result @@ -7172,7 +7172,7 @@ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [ ]> { let summary = "Generated op for `prim::max.self_int : (int[]) -> (int)`"; let arguments = (ins - ListOfTorchIntType:$self + AnyTorchListOfTorchIntType:$self ); let results = (outs Torch_IntType:$result @@ -7221,7 +7221,7 @@ def Torch_PrimRaiseExceptionOp : Torch_Op<"prim.RaiseException", [ let summary = "Generated op for `prim::RaiseException : (str, str?) -> ()`"; let arguments = (ins Torch_StringType:$msg, - TorchOptionalStringType:$cls + AnyTorchOptionalStringType:$cls ); let results = (outs ); diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index f57a6b006..c9399c969 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -951,7 +951,7 @@ def Torch_ValsemVariantAtenUniformOp: Torch_Op<"valsem.aten.uniform", [ AnyTorchTensorType:$self, Torch_FloatType:$from, Torch_FloatType:$to, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -970,7 +970,7 @@ def Torch_ValsemVariantAtenBernoulliFloatOp: Torch_Op<"valsem.aten.bernoulli.flo let arguments = (ins AnyTorchTensorType:$self, Torch_FloatType:$p, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -989,7 +989,7 @@ def Torch_ValsemVariantAtenBernoulliTensorOp: Torch_Op<"valsem.aten.bernoulli.Te let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$p, - TorchOptionalGeneratorType:$generator + AnyTorchOptionalGeneratorType:$generator ); let results = (outs AnyTorchTensorType:$result @@ -1171,7 +1171,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes", }]; let arguments = (ins - Variadic:$results + Variadic:$results ); let results = (outs); diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 04cc55010..d961bb44c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -354,14 +354,14 @@ class OptionalOf : def AnyTorchOptionalTensorType : OptionalOf; -def TorchOptionalIntType: OptionalOf; -def TorchOptionalBoolType: +def AnyTorchOptionalIntType: OptionalOf; +def AnyTorchOptionalBoolType: OptionalOf; -def TorchOptionalStringType: +def AnyTorchOptionalStringType: OptionalOf; -def TorchOptionalDeviceType: +def AnyTorchOptionalDeviceType: OptionalOf; -def TorchOptionalGeneratorType: +def AnyTorchOptionalGeneratorType: OptionalOf; def IsListTypePred : CPred<"$_self.isa<::mlir::torch::Torch::ListType>()">; @@ -371,15 +371,15 @@ class ListOf allowedTypes, string descr> : "$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()", descr, "::mlir::torch::Torch::ListType">; -def ListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; -def ListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">; -def ListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">; +def AnyTorchListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">; +def AnyTorchListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">; +def AnyTorchListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">; def AnyTorchListOfTensorType: ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">; def AnyTorchListOfOptionalTensorType : ListOf<[AnyTorchOptionalTensorType], "Any optional tensor list type (Tensor?[])">; -def OptionalListOfTorchIntType : OptionalOf; +def AnyTorchOptionalListOfTorchIntType : OptionalOf; // Note: TorchScript does not consider !torch.bool to be a Scalar. def AnyTorchScalarType : diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py index 7aa71fb69..63677c1e3 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py @@ -25,12 +25,12 @@ TORCH_TYPE_TO_ODS_TYPE = { "Scalar": "AnyTorchScalarType", "Scalar?": "AnyTorchOptionalScalarType", "int": "Torch_IntType", - "int[]": "ListOfTorchIntType", - "int?": "TorchOptionalIntType", - "int[]?": "OptionalListOfTorchIntType", + "int[]": "AnyTorchListOfTorchIntType", + "int?": "AnyTorchOptionalIntType", + "int[]?": "AnyTorchOptionalListOfTorchIntType", "bool": "Torch_BoolType", - "bool[]": "ListOfTorchBoolType", - "bool?": "TorchOptionalBoolType", + "bool[]": "AnyTorchListOfTorchBoolType", + "bool?": "AnyTorchOptionalBoolType", "float": "Torch_FloatType", "t[]": "AnyTorchListType", "t": "AnyTorchType", @@ -38,12 +38,12 @@ TORCH_TYPE_TO_ODS_TYPE = { "t2": "AnyTorchType", "Any": "AnyTorchType", "Device": "Torch_DeviceType", - "Device?": "TorchOptionalDeviceType", + "Device?": "AnyTorchOptionalDeviceType", "Generator": "Torch_GeneratorType", - "Generator?": "TorchOptionalGeneratorType", + "Generator?": "AnyTorchOptionalGeneratorType", "str": "Torch_StringType", - "str?": "TorchOptionalStringType", - "str[]": "ListOfTorchStringType", + "str?": "AnyTorchOptionalStringType", + "str[]": "AnyTorchListOfTorchStringType", "Dict": "Torch_DictType", "__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType", }