diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index 4524b9d5a..c852dd613 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -34,6 +34,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className); +/// Gets the !torch.nn.Module typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.optional type. //===----------------------------------------------------------------------===// @@ -49,6 +52,9 @@ torchMlirTorchOptionalTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType torchMlirTorchOptionalTypeGetContained(MlirType containedType); +/// Gets the !torch.optional typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.tuple type. //===----------------------------------------------------------------------===// @@ -65,7 +71,11 @@ torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t); /// Returns the pos-th type in the !torch.tuple type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos); +MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, + intptr_t pos); + +/// Gets the !torch.tuple typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(); //===----------------------------------------------------------------------===// // torch.union type. @@ -83,7 +93,11 @@ torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t); /// Returns the pos-th type in the !torch.union type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos); +MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, + intptr_t pos); + +/// Gets the !torch.union typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(); //===----------------------------------------------------------------------===// // torch.list type. @@ -98,6 +112,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); /// Gets contained T in a !torch.list type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t); +/// Gets the !torch.list typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.Device type. //===----------------------------------------------------------------------===// @@ -108,6 +125,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t); /// Gets the !torch.Device type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context); +/// Gets the !torch.device typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.Generator type. //===----------------------------------------------------------------------===// @@ -118,6 +138,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t); /// Gets the !torch.Generator type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context); +/// Gets the !torch.generator typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.bool type. //===----------------------------------------------------------------------===// @@ -128,6 +151,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t); /// Gets the !torch.bool type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context); +/// Gets the !torch.bool typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.int type. //===----------------------------------------------------------------------===// @@ -138,6 +164,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t); /// Gets the !torch.int type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context); +/// Gets the !torch.int typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.float type. //===----------------------------------------------------------------------===// @@ -148,6 +177,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t); /// Gets the !torch.float type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context); +/// Gets the !torch.float typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.LinearParams type. //===----------------------------------------------------------------------===// @@ -159,6 +191,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context); +/// Gets the !torch.linearparams typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.qint8 type. //===----------------------------------------------------------------------===// @@ -169,6 +204,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t); /// Gets the !torch.qint8 type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context); +/// Gets the !torch.qint8 typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.quint8 type. //===----------------------------------------------------------------------===// @@ -179,6 +217,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t); /// Gets the !torch.quint8 type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); +/// Gets the !torch.quint8 typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(); + //===----------------------------------------------------------------------===// // torch.tensor type. //===----------------------------------------------------------------------===// @@ -217,10 +258,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t); /// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size /// indicates an unrefined/unknown size dimension. -MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes); +MLIR_CAPI_EXPORTED int64_t +torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes); /// Gets the the dtype (data type) of a !torch.tensor. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); +MLIR_CAPI_EXPORTED MlirType +torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); + +/// Gets the !torch.tensor typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(); //===----------------------------------------------------------------------===// // torch.vtensor type. @@ -259,11 +305,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t); /// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size /// indicates an unrefined/unknown size dimension. -MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); +MLIR_CAPI_EXPORTED int64_t +torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); /// Gets the the dtype (data type) of a !torch.vtensor. MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t); +/// Gets the !torch.vtensor typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.none type. //===----------------------------------------------------------------------===// @@ -274,6 +324,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t); /// Gets the !torch.none type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context); +/// Gets the !torch.none typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.str type. //===----------------------------------------------------------------------===// @@ -284,6 +337,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t); /// Gets the !torch.str type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context); +/// Gets the !torch.str typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.any type. //===----------------------------------------------------------------------===// @@ -294,6 +350,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t); /// Gets the !torch.str type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context); +/// Gets the !torch.any typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.number type. //===----------------------------------------------------------------------===// @@ -304,6 +363,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t); /// Gets the !torch.number type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context); +/// Gets the !torch.number typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(); + //===----------------------------------------------------------------------===// // !torch.dict type. //===----------------------------------------------------------------------===// @@ -324,6 +386,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t); /// Gets the value type of a !torch.dict type. MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t); +/// Gets the !torch.dict typeid. +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(); + #ifdef __cplusplus } #endif diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 76ae43c2c..f4a9ca032 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -34,6 +34,10 @@ MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); } +MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() { + return wrap(Torch::NnModuleType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.optional type. //===----------------------------------------------------------------------===// @@ -47,8 +51,12 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { } MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) { - auto type = unwrap(t).cast(); - return wrap(type.getContainedType()); + auto type = unwrap(t).cast(); + return wrap(type.getContainedType()); +} + +MlirTypeID torchMlirTorchOptionalTypeGetTypeID() { + return wrap(Torch::OptionalType::getTypeID()); } //===----------------------------------------------------------------------===// @@ -63,10 +71,9 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, MlirType const *containedTypes) { return wrap(Torch::TupleType::get( - unwrap(context), - llvm::to_vector<6>( - llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), - [](MlirType t) { return unwrap(t); })))); + unwrap(context), llvm::to_vector<6>(llvm::map_range( + llvm::ArrayRef(containedTypes, numContainedTypes), + [](MlirType t) { return unwrap(t); })))); } size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) { @@ -79,6 +86,10 @@ MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) { return wrap(type.getContainedTypes()[pos]); } +MlirTypeID torchMlirTorchTupleTypeGetTypeID() { + return wrap(Torch::TupleType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.union type. //===----------------------------------------------------------------------===// @@ -91,10 +102,9 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, MlirType const *containedTypes) { return wrap(Torch::UnionType::get( - unwrap(context), - llvm::to_vector<6>( - llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes), - [](MlirType t) { return unwrap(t); })))); + unwrap(context), llvm::to_vector<6>(llvm::map_range( + llvm::ArrayRef(containedTypes, numContainedTypes), + [](MlirType t) { return unwrap(t); })))); } size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) { @@ -107,6 +117,10 @@ MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) { return wrap(type.getContainedTypes()[pos]); } +MlirTypeID torchMlirTorchUnionTypeGetTypeID() { + return wrap(Torch::UnionType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.list type. //===----------------------------------------------------------------------===// @@ -123,6 +137,10 @@ MlirType torchMlirTorchListTypeGetContainedType(MlirType t) { return wrap(unwrap(t).cast().getContainedType()); } +MlirTypeID torchMlirTorchListTypeGetTypeID() { + return wrap(Torch::ListType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.Device type. //===----------------------------------------------------------------------===// @@ -135,6 +153,10 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) { return wrap(Torch::DeviceType::get(unwrap(context))); } +MlirTypeID torchMlirTorchDeviceTypeGetTypeID() { + return wrap(Torch::DeviceType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.Generator type. //===----------------------------------------------------------------------===// @@ -147,6 +169,10 @@ MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) { return wrap(Torch::GeneratorType::get(unwrap(context))); } +MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() { + return wrap(Torch::GeneratorType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.bool type. //===----------------------------------------------------------------------===// @@ -159,6 +185,10 @@ MlirType torchMlirTorchBoolTypeGet(MlirContext context) { return wrap(Torch::BoolType::get(unwrap(context))); } +MlirTypeID torchMlirTorchBoolTypeGetTypeID() { + return wrap(Torch::BoolType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.int type. //===----------------------------------------------------------------------===// @@ -171,6 +201,10 @@ MlirType torchMlirTorchIntTypeGet(MlirContext context) { return wrap(Torch::IntType::get(unwrap(context))); } +MlirTypeID torchMlirTorchIntTypeGetTypeID() { + return wrap(Torch::IntType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.float type. //===----------------------------------------------------------------------===// @@ -183,6 +217,10 @@ MlirType torchMlirTorchFloatTypeGet(MlirContext context) { return wrap(Torch::FloatType::get(unwrap(context))); } +MlirTypeID torchMlirTorchFloatTypeGetTypeID() { + return wrap(Torch::FloatType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.LinearParams type. //===----------------------------------------------------------------------===// @@ -195,6 +233,10 @@ MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) { return wrap(Torch::LinearParamsType::get(unwrap(context))); } +MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() { + return wrap(Torch::LinearParamsType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.qint8 type. //===----------------------------------------------------------------------===// @@ -207,6 +249,10 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) { return wrap(Torch::QInt8Type::get(unwrap(context))); } +MlirTypeID torchMlirTorchQInt8TypeGetTypeID() { + return wrap(Torch::QInt8Type::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.quint8 type. //===----------------------------------------------------------------------===// @@ -219,6 +265,10 @@ MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) { return wrap(Torch::QUInt8Type::get(unwrap(context))); } +MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() { + return wrap(Torch::QUInt8Type::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.tensor type. //===----------------------------------------------------------------------===// @@ -258,11 +308,11 @@ int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) { } bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); + return unwrap(t).cast().hasSizes(); } bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); + return unwrap(t).cast().hasDtype(); } int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { @@ -282,6 +332,10 @@ MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) { return wrap(unwrap(t).cast().getDtype()); } +MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() { + return wrap(Torch::NonValueTensorType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.vtensor type. //===----------------------------------------------------------------------===// @@ -321,11 +375,11 @@ int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) { } bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); + return unwrap(t).cast().hasSizes(); } bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); + return unwrap(t).cast().hasDtype(); } int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { @@ -345,6 +399,10 @@ MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) { return wrap(unwrap(t).cast().getDtype()); } +MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() { + return wrap(Torch::ValueTensorType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.none type. //===----------------------------------------------------------------------===// @@ -357,6 +415,10 @@ MlirType torchMlirTorchNoneTypeGet(MlirContext context) { return wrap(Torch::NoneType::get(unwrap(context))); } +MlirTypeID torchMlirTorchNoneTypeGetTypeID() { + return wrap(Torch::NoneType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.str type. //===----------------------------------------------------------------------===// @@ -369,6 +431,10 @@ MlirType torchMlirTorchStringTypeGet(MlirContext context) { return wrap(Torch::StringType::get(unwrap(context))); } +MlirTypeID torchMlirTorchStringTypeGetTypeID() { + return wrap(Torch::StringType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.any type. //===----------------------------------------------------------------------===// @@ -381,6 +447,10 @@ MlirType torchMlirTorchAnyTypeGet(MlirContext context) { return wrap(Torch::AnyType::get(unwrap(context))); } +MlirTypeID torchMlirTorchAnyTypeGetTypeID() { + return wrap(Torch::AnyType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.number type. //===----------------------------------------------------------------------===// @@ -393,6 +463,10 @@ MlirType torchMlirTorchNumberTypeGet(MlirContext context) { return wrap(Torch::NumberType::get(unwrap(context))); } +MlirTypeID torchMlirTorchNumberTypeGetTypeID() { + return wrap(Torch::NumberType::getTypeID()); +} + //===----------------------------------------------------------------------===// // torch.Dict type. //===----------------------------------------------------------------------===// @@ -413,11 +487,15 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType, } MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) { - auto type = unwrap(t).cast(); - return wrap(type.getKeyType()); + auto type = unwrap(t).cast(); + return wrap(type.getKeyType()); } MlirType torchMlirTorchDictTypeGetValueType(MlirType t) { - auto type = unwrap(t).cast(); - return wrap(type.getValueType()); + auto type = unwrap(t).cast(); + return wrap(type.getValueType()); +} + +MlirTypeID torchMlirTorchDictTypeGetTypeID() { + return wrap(Torch::DictType::getTypeID()); }