mirror of https://github.com/llvm/torch-mlir
Rename optional list types (#643)
parent
e7721fb784
commit
fa0b24a73c
|
@ -2556,7 +2556,7 @@ def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [
|
|||
let summary = "Generated op for `aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchListOfOptionalTensorType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate
|
||||
);
|
||||
|
@ -2581,7 +2581,7 @@ def Torch_AtenIndexPut_Op : Torch_Op<"aten.index_put_", [
|
|||
let summary = "Generated op for `aten::index_put_ : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchListOfOptionalTensorType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate
|
||||
);
|
||||
|
@ -2709,9 +2709,9 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [
|
|||
AnyTorchTensorType:$input,
|
||||
AnyTorchTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$bias,
|
||||
TorchIntListType:$stride,
|
||||
TorchIntListType:$padding,
|
||||
TorchIntListType:$dilation,
|
||||
ListOfTorchIntType:$stride,
|
||||
ListOfTorchIntType:$padding,
|
||||
ListOfTorchIntType:$dilation,
|
||||
Torch_IntType:$groups
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -2799,7 +2799,7 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [
|
|||
let summary = "Generated op for `aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
TorchIntListType:$normalized_shape,
|
||||
ListOfTorchIntType:$normalized_shape,
|
||||
AnyTorchOptionalTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$bias,
|
||||
Torch_FloatType:$eps,
|
||||
|
@ -2827,7 +2827,7 @@ def Torch_AtenNativeLayerNormOp : Torch_Op<"aten.native_layer_norm", [
|
|||
let summary = "Generated op for `aten::native_layer_norm : (Tensor, int[], Tensor?, Tensor?, float) -> (Tensor, Tensor, Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$input,
|
||||
TorchIntListType:$normalized_shape,
|
||||
ListOfTorchIntType:$normalized_shape,
|
||||
AnyTorchOptionalTensorType:$weight,
|
||||
AnyTorchOptionalTensorType:$bias,
|
||||
Torch_FloatType:$eps
|
||||
|
@ -2856,10 +2856,10 @@ def Torch_AtenMaxPool2dOp : Torch_Op<"aten.max_pool2d", [
|
|||
let summary = "Generated op for `aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$kernel_size,
|
||||
TorchIntListType:$stride,
|
||||
TorchIntListType:$padding,
|
||||
TorchIntListType:$dilation,
|
||||
ListOfTorchIntType:$kernel_size,
|
||||
ListOfTorchIntType:$stride,
|
||||
ListOfTorchIntType:$padding,
|
||||
ListOfTorchIntType:$dilation,
|
||||
Torch_BoolType:$ceil_mode
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -2959,7 +2959,7 @@ def Torch_AtenAdaptiveAvgPool2dOp : Torch_Op<"aten.adaptive_avg_pool2d", [
|
|||
let summary = "Generated op for `aten::adaptive_avg_pool2d : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$output_size
|
||||
ListOfTorchIntType:$output_size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -3034,7 +3034,7 @@ def Torch_AtenPermuteOp : Torch_Op<"aten.permute", [
|
|||
let summary = "Generated op for `aten::permute : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$dims
|
||||
ListOfTorchIntType:$dims
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -3131,7 +3131,7 @@ def Torch_AtenLogsumexpOp : Torch_Op<"aten.logsumexp", [
|
|||
let summary = "Generated op for `aten::logsumexp : (Tensor, int[], bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$dim,
|
||||
ListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$keepdim
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -3156,7 +3156,7 @@ def Torch_AtenMeanDimOp : Torch_Op<"aten.mean.dim", [
|
|||
let summary = "Generated op for `aten::mean.dim : (Tensor, int[], bool, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$dim,
|
||||
ListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$keepdim,
|
||||
TorchOptionalIntType:$dtype
|
||||
);
|
||||
|
@ -3408,7 +3408,7 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [
|
|||
let summary = "Generated op for `aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$pad,
|
||||
ListOfTorchIntType:$pad,
|
||||
AnyTorchScalarType:$value
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -3553,7 +3553,7 @@ def Torch_AtenSizeOp : Torch_Op<"aten.size", [
|
|||
AnyTorchTensorType:$self
|
||||
);
|
||||
let results = (outs
|
||||
TorchIntListType:$result
|
||||
ListOfTorchIntType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
|
@ -3597,7 +3597,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
|
@ -3625,7 +3625,7 @@ def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [
|
|||
let summary = "Generated op for `aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
|
@ -3652,7 +3652,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
|
@ -3701,7 +3701,7 @@ def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [
|
|||
let summary = "Generated op for `aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
|
@ -4212,7 +4212,7 @@ def Torch_AtenNewEmptyOp : Torch_Op<"aten.new_empty", [
|
|||
let summary = "Generated op for `aten::new_empty : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
|
@ -4295,7 +4295,7 @@ def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
TorchOptionalDeviceType:$device,
|
||||
|
@ -4323,7 +4323,7 @@ def Torch_AtenExpandOp : Torch_Op<"aten.expand", [
|
|||
let summary = "Generated op for `aten::expand : (Tensor, int[], bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
Torch_BoolType:$implicit
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -4370,7 +4370,7 @@ def Torch_AtenBroadcastToOp : Torch_Op<"aten.broadcast_to", [
|
|||
let summary = "Generated op for `aten::broadcast_to : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size
|
||||
ListOfTorchIntType:$size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -4394,7 +4394,7 @@ def Torch_AtenIndexTensorOp : Torch_Op<"aten.index.Tensor", [
|
|||
let summary = "Generated op for `aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices
|
||||
AnyTorchListOfOptionalTensorType:$indices
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -4441,7 +4441,7 @@ def Torch_Aten_IndexPutImpl_Op : Torch_Op<"aten._index_put_impl_", [
|
|||
let summary = "Generated op for `aten::_index_put_impl_ : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchListOfOptionalTensorType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate,
|
||||
Torch_BoolType:$unsafe
|
||||
|
@ -4538,7 +4538,7 @@ def Torch_AtenRepeatOp : Torch_Op<"aten.repeat", [
|
|||
let summary = "Generated op for `aten::repeat : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$repeats
|
||||
ListOfTorchIntType:$repeats
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -4561,7 +4561,7 @@ def Torch_AtenReshapeOp : Torch_Op<"aten.reshape", [
|
|||
let summary = "Generated op for `aten::reshape : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$shape
|
||||
ListOfTorchIntType:$shape
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -4584,8 +4584,8 @@ def Torch_Aten_ReshapeAliasOp : Torch_Op<"aten._reshape_alias", [
|
|||
let summary = "Generated op for `aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
TorchIntListType:$stride
|
||||
ListOfTorchIntType:$size,
|
||||
ListOfTorchIntType:$stride
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -4607,7 +4607,7 @@ def Torch_AtenResize_Op : Torch_Op<"aten.resize_", [
|
|||
let summary = "Generated op for `aten::resize_ : (Tensor, int[], int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
TorchOptionalIntType:$memory_format
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -4680,7 +4680,7 @@ def Torch_AtenStackOp : Torch_Op<"aten.stack", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::stack : (Tensor[], int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorListType:$tensors,
|
||||
AnyTorchListOfTensorType:$tensors,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -4729,7 +4729,7 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
|
|||
let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$dim,
|
||||
ListOfTorchIntType:$dim,
|
||||
Torch_BoolType:$keepdim,
|
||||
TorchOptionalIntType:$dtype
|
||||
);
|
||||
|
@ -4906,7 +4906,7 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
|
|||
let summary = "Generated op for `aten::view : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size
|
||||
ListOfTorchIntType:$size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -4931,7 +4931,7 @@ def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
|
|||
let summary = "Generated op for `aten::_unsafe_view : (Tensor, int[]) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
TorchIntListType:$size
|
||||
ListOfTorchIntType:$size
|
||||
);
|
||||
let results = (outs
|
||||
AnyTorchTensorType:$result
|
||||
|
@ -5295,7 +5295,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$size,
|
||||
ListOfTorchIntType:$size,
|
||||
AnyTorchScalarType:$fill_value,
|
||||
TorchOptionalIntType:$dtype,
|
||||
TorchOptionalIntType:$layout,
|
||||
|
@ -5425,7 +5425,7 @@ def Torch_AtenKeysStrOp : Torch_Op<"aten.keys.str", [
|
|||
Torch_DictType:$self
|
||||
);
|
||||
let results = (outs
|
||||
TorchStringListType:$result
|
||||
ListOfTorchStringType:$result
|
||||
);
|
||||
let hasCustomAssemblyFormat = 1;
|
||||
let extraClassDefinition = [{
|
||||
|
@ -5490,7 +5490,7 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::cat : (Tensor[], int) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorListType:$tensors,
|
||||
AnyTorchListOfTensorType:$tensors,
|
||||
Torch_IntType:$dim
|
||||
);
|
||||
let results = (outs
|
||||
|
@ -5560,8 +5560,8 @@ def Torch_AtenEqIntListOp : Torch_Op<"aten.eq.int_list", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::eq.int_list : (int[], int[]) -> (bool)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$a,
|
||||
TorchIntListType:$b
|
||||
ListOfTorchIntType:$a,
|
||||
ListOfTorchIntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
|
@ -5656,8 +5656,8 @@ def Torch_AtenNeIntListOp : Torch_Op<"aten.ne.int_list", [
|
|||
]> {
|
||||
let summary = "Generated op for `aten::ne.int_list : (int[], int[]) -> (bool)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$a,
|
||||
TorchIntListType:$b
|
||||
ListOfTorchIntType:$a,
|
||||
ListOfTorchIntType:$b
|
||||
);
|
||||
let results = (outs
|
||||
Torch_BoolType:$result
|
||||
|
@ -5767,7 +5767,7 @@ def Torch_AtenJoinOp : Torch_Op<"aten.join", [
|
|||
let summary = "Generated op for `aten::join : (str, str[]) -> (str)`";
|
||||
let arguments = (ins
|
||||
Torch_StringType:$self,
|
||||
TorchStringListType:$values
|
||||
ListOfTorchStringType:$values
|
||||
);
|
||||
let results = (outs
|
||||
Torch_StringType:$result
|
||||
|
@ -6905,7 +6905,7 @@ def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [
|
|||
]> {
|
||||
let summary = "Generated op for `prim::min.self_int : (int[]) -> (int)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$self
|
||||
ListOfTorchIntType:$self
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
|
@ -6953,7 +6953,7 @@ def Torch_PrimMaxSelfIntOp : Torch_Op<"prim.max.self_int", [
|
|||
]> {
|
||||
let summary = "Generated op for `prim::max.self_int : (int[]) -> (int)`";
|
||||
let arguments = (ins
|
||||
TorchIntListType:$self
|
||||
ListOfTorchIntType:$self
|
||||
);
|
||||
let results = (outs
|
||||
Torch_IntType:$result
|
||||
|
|
|
@ -1042,7 +1042,7 @@ def Torch_ValsemVariantAtenIndexPutImplOp: Torch_Op<"valsem.aten.index_put_impl"
|
|||
let summary = "`index_put_impl op : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)`";
|
||||
let arguments = (ins
|
||||
AnyTorchTensorType:$self,
|
||||
AnyTorchOptionalTensorListType:$indices,
|
||||
AnyTorchListOfOptionalTensorType:$indices,
|
||||
AnyTorchTensorType:$values,
|
||||
Torch_BoolType:$accumulate,
|
||||
Torch_BoolType:$unsafe
|
||||
|
@ -1171,7 +1171,7 @@ def Torch_ShapeCalculateYieldShapesOp : Torch_Op<"shape.calculate.yield.shapes",
|
|||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Variadic<TorchIntListType>:$results
|
||||
Variadic<ListOfTorchIntType>:$results
|
||||
);
|
||||
let results = (outs);
|
||||
|
||||
|
|
|
@ -425,14 +425,15 @@ class ListOf<list<Type> allowedTypes, string descr> :
|
|||
"$_self.cast<::mlir::torch::Torch::ListType>().getContainedType()",
|
||||
descr, "::mlir::torch::Torch::ListType">;
|
||||
|
||||
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;
|
||||
def TorchStringListType : ListOf<[Torch_StringType], "Str list type (str[])">;
|
||||
def AnyTorchTensorListType:
|
||||
def ListOfTorchBoolType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||
def ListOfTorchIntType : ListOf<[Torch_IntType], "Int list type (int[])">;
|
||||
def ListOfTorchStringType : ListOf<[Torch_StringType], "Str list type (str[])">;
|
||||
def AnyTorchListOfTensorType:
|
||||
ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;
|
||||
def AnyTorchOptionalTensorListType :
|
||||
def AnyTorchListOfOptionalTensorType :
|
||||
ListOf<[AnyTorchOptionalTensorType],
|
||||
"Any optional tensor list type (Tensor?[])">;
|
||||
def OptionalListOfTorchIntType : OptionalOf<ListOfTorchIntType, "Optional torch int list type (int[]?)">;
|
||||
|
||||
// Note: TorchScript does not consider !torch.bool to be a Scalar.
|
||||
def AnyTorchScalarType :
|
||||
|
|
|
@ -20,15 +20,16 @@ from .registry import Registry, JitOperator
|
|||
TORCH_TYPE_TO_ODS_TYPE = {
|
||||
"Tensor": "AnyTorchTensorType",
|
||||
"Tensor?": "AnyTorchOptionalTensorType",
|
||||
"Tensor?[]": "AnyTorchOptionalTensorListType",
|
||||
"Tensor[]": "AnyTorchTensorListType",
|
||||
"Tensor?[]": "AnyTorchListOfOptionalTensorType",
|
||||
"Tensor[]": "AnyTorchListOfTensorType",
|
||||
"Scalar": "AnyTorchScalarType",
|
||||
"Scalar?": "AnyTorchOptionalScalarType",
|
||||
"int": "Torch_IntType",
|
||||
"int[]": "TorchIntListType",
|
||||
"int[]": "ListOfTorchIntType",
|
||||
"int?": "TorchOptionalIntType",
|
||||
"int[]?": "OptionalListOfTorchIntType",
|
||||
"bool": "Torch_BoolType",
|
||||
"bool[]": "TorchBoolListType",
|
||||
"bool[]": "ListOfTorchBoolType",
|
||||
"bool?": "TorchOptionalBoolType",
|
||||
"float": "Torch_FloatType",
|
||||
"t[]": "AnyTorchListType",
|
||||
|
@ -42,7 +43,7 @@ TORCH_TYPE_TO_ODS_TYPE = {
|
|||
"Generator?": "TorchOptionalGeneratorType",
|
||||
"str": "Torch_StringType",
|
||||
"str?": "TorchOptionalStringType",
|
||||
"str[]": "TorchStringListType",
|
||||
"str[]": "ListOfTorchStringType",
|
||||
"Dict": "Torch_DictType",
|
||||
"__torch__.torch.classes.quantized.LinearPackedParamsBase": "Torch_LinearParamsType",
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue