Add non-RNG aten ops to aten dialect.

This commit adds the aten ops which do not require random number
support to aten dialect. This commit also adds some of the missing
torch types.

Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>
pull/524/head snapshot-20220114.205
Prateek Gupta 2022-01-11 20:59:05 +00:00
parent abd61b4974
commit c9a343267c
3 changed files with 34125 additions and 6 deletions

File diff suppressed because it is too large Load Diff

View File

@ -378,6 +378,8 @@ class OptionalOf<Type type, string descr> :
def AnyTorchOptionalTensorType :
OptionalOf<AnyTorchTensorType, "Optional torch tensor type">;
def TorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
def TorchOptionalFloatType: OptionalOf<Torch_FloatType, "Optional torch float type">;
def TorchOptionalStringType: OptionalOf<Torch_StringType, "Optional torch string type">;
def TorchOptionalBoolType:
OptionalOf<Torch_BoolType, "Optional torch bool type">;
def TorchOptionalDeviceType:
@ -392,10 +394,16 @@ class ListOf<list<Type> allowedTypes, string descr> :
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;
def TorchFloatListType: ListOf<[Torch_FloatType], "Float list type (float[])">;
def TorchStringListType : ListOf<[Torch_StringType], "Str list type (str[])">;
def TorchOptionalIntListType : OptionalOf<TorchIntListType, "Optional torch int list type (int[]?)">;
def TorchOptionalFloatListType : OptionalOf<TorchFloatListType, "Optional torch float list type (float[]?)">;
def TorchOptionalStringListType : OptionalOf<TorchStringListType, "Optional torch string list type (string[]?)">;
def AnyTorchTensorListType:
ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;
def AnyTorchOptionalTensorListType :
def AnyTorchListOptionalTensorType :
ListOf<[AnyTorchOptionalTensorType],
"Any optional tensor list type (Tensor?[])">;
@ -407,6 +415,8 @@ def AnyTorchScalarType : AnyTypeOf<[
], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
def AnyTorchOptionalScalarType:
OptionalOf<AnyTorchScalarType, "Optional torch scalar type">;
def AnyTorchScalarListType:
ListOf<[AnyTorchScalarType], "Any list type (Scalar[])">;
// See function `DictTypePtr create(TypePtr key, TypePtr value)`
// in aten/src/ATen/core/jit_type.h.