mirror of https://github.com/llvm/torch-mlir
Add non-RNG aten ops to aten dialect.
This commit adds the aten ops which do not require random number support to aten dialect. This commit also adds some of the missing torch types. Signed-Off-By: Prateek Gupta <prateek@nod-labs.com>pull/524/head snapshot-20220114.205
parent
abd61b4974
commit
c9a343267c
File diff suppressed because it is too large
Load Diff
|
@ -378,6 +378,8 @@ class OptionalOf<Type type, string descr> :
|
|||
def AnyTorchOptionalTensorType :
|
||||
OptionalOf<AnyTorchTensorType, "Optional torch tensor type">;
|
||||
def TorchOptionalIntType: OptionalOf<Torch_IntType, "Optional torch int type">;
|
||||
def TorchOptionalFloatType: OptionalOf<Torch_FloatType, "Optional torch float type">;
|
||||
def TorchOptionalStringType: OptionalOf<Torch_StringType, "Optional torch string type">;
|
||||
def TorchOptionalBoolType:
|
||||
OptionalOf<Torch_BoolType, "Optional torch bool type">;
|
||||
def TorchOptionalDeviceType:
|
||||
|
@ -392,10 +394,16 @@ class ListOf<list<Type> allowedTypes, string descr> :
|
|||
|
||||
def TorchBoolListType : ListOf<[Torch_BoolType], "Bool list type (bool[])">;
|
||||
def TorchIntListType : ListOf<[Torch_IntType], "Int list type (int[])">;
|
||||
def TorchFloatListType: ListOf<[Torch_FloatType], "Float list type (float[])">;
|
||||
def TorchStringListType : ListOf<[Torch_StringType], "Str list type (str[])">;
|
||||
|
||||
def TorchOptionalIntListType : OptionalOf<TorchIntListType, "Optional torch int list type (int[]?)">;
|
||||
def TorchOptionalFloatListType : OptionalOf<TorchFloatListType, "Optional torch float list type (float[]?)">;
|
||||
def TorchOptionalStringListType : OptionalOf<TorchStringListType, "Optional torch string list type (string[]?)">;
|
||||
|
||||
def AnyTorchTensorListType:
|
||||
ListOf<[AnyTorchTensorType], "Any int list type (Tensor[])">;
|
||||
def AnyTorchOptionalTensorListType :
|
||||
def AnyTorchListOptionalTensorType :
|
||||
ListOf<[AnyTorchOptionalTensorType],
|
||||
"Any optional tensor list type (Tensor?[])">;
|
||||
|
||||
|
@ -407,6 +415,8 @@ def AnyTorchScalarType : AnyTypeOf<[
|
|||
], "Any Python numeric type compatible with being the scalar type of a tensor (`Scalar`)">;
|
||||
def AnyTorchOptionalScalarType:
|
||||
OptionalOf<AnyTorchScalarType, "Optional torch scalar type">;
|
||||
def AnyTorchScalarListType:
|
||||
ListOf<[AnyTorchScalarType], "Any list type (Scalar[])">;
|
||||
|
||||
// See function `DictTypePtr create(TypePtr key, TypePtr value)`
|
||||
// in aten/src/ATen/core/jit_type.h.
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue