Add typeids to CAPI. (#2253)

pull/2254/head
Maksim Levental 2023-06-20 22:06:43 -05:00 committed by GitHub
parent ebda611100
commit 0244f540a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 166 additions and 23 deletions

View File

@ -34,6 +34,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className); torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
/// Gets the !torch.nn.Module typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.optional type. // torch.optional type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -49,6 +52,9 @@ torchMlirTorchOptionalTypeGet(MlirType containedType);
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGetContained(MlirType containedType); torchMlirTorchOptionalTypeGetContained(MlirType containedType);
/// Gets the !torch.optional typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type. // torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -65,7 +71,11 @@ torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t); MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);
/// Returns the pos-th type in the !torch.tuple type. /// Returns the pos-th type in the !torch.tuple type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos); MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t,
intptr_t pos);
/// Gets the !torch.tuple typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type. // torch.union<T1, T2, T3> type.
@ -83,7 +93,11 @@ torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t); MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);
/// Returns the pos-th type in the !torch.union type. /// Returns the pos-th type in the !torch.union type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos); MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t,
intptr_t pos);
/// Gets the !torch.union typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.list<T> type. // torch.list<T> type.
@ -98,6 +112,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
/// Gets contained T in a !torch.list<T> type. /// Gets contained T in a !torch.list<T> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
/// Gets the !torch.list typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Device type. // torch.Device type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -108,6 +125,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
/// Gets the !torch.Device type. /// Gets the !torch.Device type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
/// Gets the !torch.device typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Generator type. // torch.Generator type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -118,6 +138,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
/// Gets the !torch.Generator type. /// Gets the !torch.Generator type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
/// Gets the !torch.generator typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.bool type. // torch.bool type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -128,6 +151,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
/// Gets the !torch.bool type. /// Gets the !torch.bool type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
/// Gets the !torch.bool typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.int type. // torch.int type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -138,6 +164,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
/// Gets the !torch.int type. /// Gets the !torch.int type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
/// Gets the !torch.int typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.float type. // torch.float type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -148,6 +177,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
/// Gets the !torch.float type. /// Gets the !torch.float type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
/// Gets the !torch.float typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.LinearParams type. // torch.LinearParams type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -159,6 +191,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchLinearParamsTypeGet(MlirContext context); torchMlirTorchLinearParamsTypeGet(MlirContext context);
/// Gets the !torch.linearparams typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.qint8 type. // torch.qint8 type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -169,6 +204,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
/// Gets the !torch.qint8 type. /// Gets the !torch.qint8 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
/// Gets the !torch.qint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.quint8 type. // torch.quint8 type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -179,6 +217,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
/// Gets the !torch.quint8 type. /// Gets the !torch.quint8 type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
/// Gets the !torch.quint8 typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tensor type. // torch.tensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -217,10 +258,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);
/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size /// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
/// indicates an unrefined/unknown size dimension. /// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes); MLIR_CAPI_EXPORTED int64_t
torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
/// Gets the the dtype (data type) of a !torch.tensor. /// Gets the the dtype (data type) of a !torch.tensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
/// Gets the !torch.tensor typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.vtensor type. // torch.vtensor type.
@ -259,11 +305,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);
/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size /// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
/// indicates an unrefined/unknown size dimension. /// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); MLIR_CAPI_EXPORTED int64_t
torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
/// Gets the the dtype (data type) of a !torch.vtensor. /// Gets the the dtype (data type) of a !torch.vtensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
/// Gets the !torch.vtensor typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.none type. // !torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -274,6 +324,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
/// Gets the !torch.none type. /// Gets the !torch.none type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
/// Gets the !torch.none typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.str type. // !torch.str type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -284,6 +337,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
/// Gets the !torch.str type. /// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
/// Gets the !torch.str typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.any type. // !torch.any type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -294,6 +350,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
/// Gets the !torch.str type. /// Gets the !torch.str type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
/// Gets the !torch.any typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.number type. // !torch.number type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -304,6 +363,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
/// Gets the !torch.number type. /// Gets the !torch.number type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
/// Gets the !torch.number typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID();
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.dict type. // !torch.dict type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -324,6 +386,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);
/// Gets the value type of a !torch.dict<key, value> type. /// Gets the value type of a !torch.dict<key, value> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
/// Gets the !torch.dict typeid.
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID();
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -34,6 +34,10 @@ MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
} }
MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
return wrap(Torch::NnModuleType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.optional type. // torch.optional type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -51,6 +55,10 @@ MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
return wrap(type.getContainedType()); return wrap(type.getContainedType());
} }
MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
return wrap(Torch::OptionalType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type. // torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -63,9 +71,8 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
intptr_t numContainedTypes, intptr_t numContainedTypes,
MlirType const *containedTypes) { MlirType const *containedTypes) {
return wrap(Torch::TupleType::get( return wrap(Torch::TupleType::get(
unwrap(context), unwrap(context), llvm::to_vector<6>(llvm::map_range(
llvm::to_vector<6>( llvm::ArrayRef(containedTypes, numContainedTypes),
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); })))); [](MlirType t) { return unwrap(t); }))));
} }
@ -79,6 +86,10 @@ MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
return wrap(type.getContainedTypes()[pos]); return wrap(type.getContainedTypes()[pos]);
} }
MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
return wrap(Torch::TupleType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type. // torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -91,9 +102,8 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
intptr_t numContainedTypes, intptr_t numContainedTypes,
MlirType const *containedTypes) { MlirType const *containedTypes) {
return wrap(Torch::UnionType::get( return wrap(Torch::UnionType::get(
unwrap(context), unwrap(context), llvm::to_vector<6>(llvm::map_range(
llvm::to_vector<6>( llvm::ArrayRef(containedTypes, numContainedTypes),
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
[](MlirType t) { return unwrap(t); })))); [](MlirType t) { return unwrap(t); }))));
} }
@ -107,6 +117,10 @@ MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
return wrap(type.getContainedTypes()[pos]); return wrap(type.getContainedTypes()[pos]);
} }
MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
return wrap(Torch::UnionType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.list<T> type. // torch.list<T> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -123,6 +137,10 @@ MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType()); return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
} }
MlirTypeID torchMlirTorchListTypeGetTypeID() {
return wrap(Torch::ListType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Device type. // torch.Device type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -135,6 +153,10 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context))); return wrap(Torch::DeviceType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
return wrap(Torch::DeviceType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Generator type. // torch.Generator type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -147,6 +169,10 @@ MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
return wrap(Torch::GeneratorType::get(unwrap(context))); return wrap(Torch::GeneratorType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
return wrap(Torch::GeneratorType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.bool type. // torch.bool type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -159,6 +185,10 @@ MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
return wrap(Torch::BoolType::get(unwrap(context))); return wrap(Torch::BoolType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
return wrap(Torch::BoolType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.int type. // torch.int type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -171,6 +201,10 @@ MlirType torchMlirTorchIntTypeGet(MlirContext context) {
return wrap(Torch::IntType::get(unwrap(context))); return wrap(Torch::IntType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchIntTypeGetTypeID() {
return wrap(Torch::IntType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.float type. // torch.float type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -183,6 +217,10 @@ MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
return wrap(Torch::FloatType::get(unwrap(context))); return wrap(Torch::FloatType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
return wrap(Torch::FloatType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.LinearParams type. // torch.LinearParams type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -195,6 +233,10 @@ MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context))); return wrap(Torch::LinearParamsType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
return wrap(Torch::LinearParamsType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.qint8 type. // torch.qint8 type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -207,6 +249,10 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context))); return wrap(Torch::QInt8Type::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
return wrap(Torch::QInt8Type::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.quint8 type. // torch.quint8 type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -219,6 +265,10 @@ MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
return wrap(Torch::QUInt8Type::get(unwrap(context))); return wrap(Torch::QUInt8Type::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
return wrap(Torch::QUInt8Type::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tensor type. // torch.tensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -282,6 +332,10 @@ MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype()); return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
} }
MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
return wrap(Torch::NonValueTensorType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.vtensor type. // torch.vtensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -345,6 +399,10 @@ MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype()); return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
} }
MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
return wrap(Torch::ValueTensorType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.none type. // torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -357,6 +415,10 @@ MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
return wrap(Torch::NoneType::get(unwrap(context))); return wrap(Torch::NoneType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
return wrap(Torch::NoneType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.str type. // torch.str type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -369,6 +431,10 @@ MlirType torchMlirTorchStringTypeGet(MlirContext context) {
return wrap(Torch::StringType::get(unwrap(context))); return wrap(Torch::StringType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchStringTypeGetTypeID() {
return wrap(Torch::StringType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.any type. // torch.any type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -381,6 +447,10 @@ MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
return wrap(Torch::AnyType::get(unwrap(context))); return wrap(Torch::AnyType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
return wrap(Torch::AnyType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.number type. // torch.number type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -393,6 +463,10 @@ MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
return wrap(Torch::NumberType::get(unwrap(context))); return wrap(Torch::NumberType::get(unwrap(context)));
} }
MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
return wrap(Torch::NumberType::getTypeID());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Dict type. // torch.Dict type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -421,3 +495,7 @@ MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
auto type = unwrap(t).cast<Torch::DictType>(); auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getValueType()); return wrap(type.getValueType());
} }
MlirTypeID torchMlirTorchDictTypeGetTypeID() {
return wrap(Torch::DictType::getTypeID());
}