Add typeids to CAPI. (#2253)

pull/2254/head
Maksim Levental 2023-06-20 22:06:43 -05:00 committed by GitHub
parent ebda611100
commit 0244f540a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 166 additions and 23 deletions

View File

@ -34,6 +34,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
/// Gets the !torch.nn.Module typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//
@ -49,6 +52,9 @@ torchMlirTorchOptionalTypeGet(MlirType containedType);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGetContained(MlirType containedType);
/// Gets the !torch.optional typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//
@ -65,7 +71,11 @@ torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);
/// Returns the pos-th type in the !torch.tuple type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t,
intptr_t pos);
/// Gets the !torch.tuple typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type.
@ -83,7 +93,11 @@ torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);
/// Returns the pos-th type in the !torch.union type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t,
intptr_t pos);
/// Gets the !torch.union typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.list<T> type.
@ -98,6 +112,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
/// Gets contained T in a !torch.list<T> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
/// Gets the !torch.list typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//
@ -108,6 +125,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
/// Gets the !torch.Device type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
/// Gets the !torch.device typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.Generator type.
//===----------------------------------------------------------------------===//
@ -118,6 +138,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
/// Gets the !torch.Generator type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
/// Gets the !torch.generator typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.bool type.
//===----------------------------------------------------------------------===//
@ -128,6 +151,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
/// Gets the !torch.bool type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
/// Gets the !torch.bool typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.int type.
//===----------------------------------------------------------------------===//
@ -138,6 +164,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
/// Gets the !torch.int type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
/// Gets the !torch.int typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.float type.
//===----------------------------------------------------------------------===//
@ -148,6 +177,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
/// Gets the !torch.float type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
/// Gets the !torch.float typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//
@ -159,6 +191,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchLinearParamsTypeGet(MlirContext context);
/// Gets the !torch.linearparams typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//
@ -169,6 +204,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
/// Gets the !torch.qint8 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
/// Gets the !torch.qint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.quint8 type.
//===----------------------------------------------------------------------===//
@ -179,6 +217,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
/// Gets the !torch.quint8 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
/// Gets the !torch.quint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
@ -217,10 +258,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);
/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
/// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
MLIR_CAPI_EXPORTED int64_t
torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
/// Gets the the dtype (data type) of a !torch.tensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
/// Gets the !torch.tensor typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID();
//===----------------------------------------------------------------------===//
// torch.vtensor type.
@ -259,11 +305,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);
/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
/// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
MLIR_CAPI_EXPORTED int64_t
torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
/// Gets the the dtype (data type) of a !torch.vtensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
/// Gets the !torch.vtensor typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID();
//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//
@ -274,6 +324,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
/// Gets the !torch.none type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
/// Gets the !torch.none typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID();
//===----------------------------------------------------------------------===//
// !torch.str type.
//===----------------------------------------------------------------------===//
@ -284,6 +337,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
/// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
/// Gets the !torch.str typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID();
//===----------------------------------------------------------------------===//
// !torch.any type.
//===----------------------------------------------------------------------===//
@ -294,6 +350,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
/// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
/// Gets the !torch.any typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID();
//===----------------------------------------------------------------------===//
// !torch.number type.
//===----------------------------------------------------------------------===//
@ -304,6 +363,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
/// Gets the !torch.number type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
/// Gets the !torch.number typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID();
//===----------------------------------------------------------------------===//
// !torch.dict type.
//===----------------------------------------------------------------------===//
@ -324,6 +386,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);
/// Gets the value type of a !torch.dict<key, value> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
/// Gets the !torch.dict typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID();
#ifdef __cplusplus
}
#endif

View File

@ -34,6 +34,10 @@ MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}
MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
return wrap(Torch::NnModuleType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//
@ -47,8 +51,12 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
}
MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
auto type = unwrap(t).cast<Torch::OptionalType>();
return wrap(type.getContainedType());
auto type = unwrap(t).cast<Torch::OptionalType>();
return wrap(type.getContainedType());
}
MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
return wrap(Torch::OptionalType::getTypeID());
}
//===----------------------------------------------------------------------===//
@ -63,10 +71,9 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
return wrap(Torch::TupleType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
unwrap(context), llvm::to_vector<6>(llvm::map_range(
llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}
size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
@ -79,6 +86,10 @@ MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
return wrap(type.getContainedTypes()[pos]);
}
MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
return wrap(Torch::TupleType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===//
@ -91,10 +102,9 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
intptr_t numContainedTypes,
MlirType const *containedTypes) {
return wrap(Torch::UnionType::get(
unwrap(context),
llvm::to_vector<6>(
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
unwrap(context), llvm::to_vector<6>(llvm::map_range(
llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); }))));
}
size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
@ -107,6 +117,10 @@ MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
return wrap(type.getContainedTypes()[pos]);
}
MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
return wrap(Torch::UnionType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
@ -123,6 +137,10 @@ MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
}
MlirTypeID torchMlirTorchListTypeGetTypeID() {
return wrap(Torch::ListType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//
@ -135,6 +153,10 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
return wrap(Torch::DeviceType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.Generator type.
//===----------------------------------------------------------------------===//
@ -147,6 +169,10 @@ MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
return wrap(Torch::GeneratorType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
return wrap(Torch::GeneratorType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.bool type.
//===----------------------------------------------------------------------===//
@ -159,6 +185,10 @@ MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
return wrap(Torch::BoolType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
return wrap(Torch::BoolType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.int type.
//===----------------------------------------------------------------------===//
@ -171,6 +201,10 @@ MlirType torchMlirTorchIntTypeGet(MlirContext context) {
return wrap(Torch::IntType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchIntTypeGetTypeID() {
return wrap(Torch::IntType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.float type.
//===----------------------------------------------------------------------===//
@ -183,6 +217,10 @@ MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
return wrap(Torch::FloatType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
return wrap(Torch::FloatType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//
@ -195,6 +233,10 @@ MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
return wrap(Torch::LinearParamsType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//
@ -207,6 +249,10 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context)));
}
MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
return wrap(Torch::QInt8Type::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.quint8 type.
//===----------------------------------------------------------------------===//
@ -219,6 +265,10 @@ MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
return wrap(Torch::QUInt8Type::get(unwrap(context)));
}
MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
return wrap(Torch::QUInt8Type::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
@ -258,11 +308,11 @@ int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
}
bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
}
bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
}
int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
@ -282,6 +332,10 @@ MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
}
MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
return wrap(Torch::NonValueTensorType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//
@ -321,11 +375,11 @@ int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
}
bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
}
bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
}
int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
@ -345,6 +399,10 @@ MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
}
MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
return wrap(Torch::ValueTensorType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.none type.
//===----------------------------------------------------------------------===//
@ -357,6 +415,10 @@ MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
return wrap(Torch::NoneType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
return wrap(Torch::NoneType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.str type.
//===----------------------------------------------------------------------===//
@ -369,6 +431,10 @@ MlirType torchMlirTorchStringTypeGet(MlirContext context) {
return wrap(Torch::StringType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchStringTypeGetTypeID() {
return wrap(Torch::StringType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.any type.
//===----------------------------------------------------------------------===//
@ -381,6 +447,10 @@ MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
return wrap(Torch::AnyType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
return wrap(Torch::AnyType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.number type.
//===----------------------------------------------------------------------===//
@ -393,6 +463,10 @@ MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
return wrap(Torch::NumberType::get(unwrap(context)));
}
MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
return wrap(Torch::NumberType::getTypeID());
}
//===----------------------------------------------------------------------===//
// torch.Dict type.
//===----------------------------------------------------------------------===//
@ -413,11 +487,15 @@ MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
}
MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getKeyType());
auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getKeyType());
}
MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getValueType());
auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getValueType());
}
MlirTypeID torchMlirTorchDictTypeGetTypeID() {
return wrap(Torch::DictType::getTypeID());
}