mirror of https://github.com/llvm/torch-mlir
Add typeids to CAPI. (#2253)
parent
ebda611100
commit
0244f540a7
|
@ -34,6 +34,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNnModule(MlirType t);
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||||
|
|
||||||
|
/// Gets the !torch.nn.Module typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.optional type.
|
// torch.optional type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -49,6 +52,9 @@ torchMlirTorchOptionalTypeGet(MlirType containedType);
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchOptionalTypeGetContained(MlirType containedType);
|
torchMlirTorchOptionalTypeGetContained(MlirType containedType);
|
||||||
|
|
||||||
|
/// Gets the !torch.optional typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tuple<T1, T2, T3> type.
|
// torch.tuple<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -65,7 +71,11 @@ torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||||
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);
|
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);
|
||||||
|
|
||||||
/// Returns the pos-th type in the !torch.tuple type.
|
/// Returns the pos-th type in the !torch.tuple type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t,
|
||||||
|
intptr_t pos);
|
||||||
|
|
||||||
|
/// Gets the !torch.tuple typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.union<T1, T2, T3> type.
|
// torch.union<T1, T2, T3> type.
|
||||||
|
@ -83,7 +93,11 @@ torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||||
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);
|
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);
|
||||||
|
|
||||||
/// Returns the pos-th type in the !torch.union type.
|
/// Returns the pos-th type in the !torch.union type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t,
|
||||||
|
intptr_t pos);
|
||||||
|
|
||||||
|
/// Gets the !torch.union typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.list<T> type.
|
// torch.list<T> type.
|
||||||
|
@ -98,6 +112,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
|
||||||
/// Gets contained T in a !torch.list<T> type.
|
/// Gets contained T in a !torch.list<T> type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
|
||||||
|
|
||||||
|
/// Gets the !torch.list typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Device type.
|
// torch.Device type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -108,6 +125,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t);
|
||||||
/// Gets the !torch.Device type.
|
/// Gets the !torch.Device type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.device typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Generator type.
|
// torch.Generator type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -118,6 +138,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t);
|
||||||
/// Gets the !torch.Generator type.
|
/// Gets the !torch.Generator type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.generator typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.bool type.
|
// torch.bool type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -128,6 +151,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t);
|
||||||
/// Gets the !torch.bool type.
|
/// Gets the !torch.bool type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.bool typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.int type.
|
// torch.int type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -138,6 +164,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t);
|
||||||
/// Gets the !torch.int type.
|
/// Gets the !torch.int type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.int typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.float type.
|
// torch.float type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -148,6 +177,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t);
|
||||||
/// Gets the !torch.float type.
|
/// Gets the !torch.float type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.float typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.LinearParams type.
|
// torch.LinearParams type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -159,6 +191,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchLinearParams(MlirType t);
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchLinearParamsTypeGet(MlirContext context);
|
torchMlirTorchLinearParamsTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.linearparams typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.qint8 type.
|
// torch.qint8 type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -169,6 +204,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t);
|
||||||
/// Gets the !torch.qint8 type.
|
/// Gets the !torch.qint8 type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.qint8 typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.quint8 type.
|
// torch.quint8 type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -179,6 +217,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t);
|
||||||
/// Gets the !torch.quint8 type.
|
/// Gets the !torch.quint8 type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.quint8 typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tensor type.
|
// torch.tensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -217,10 +258,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);
|
||||||
|
|
||||||
/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
|
/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
|
||||||
/// indicates an unrefined/unknown size dimension.
|
/// indicates an unrefined/unknown size dimension.
|
||||||
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
MLIR_CAPI_EXPORTED int64_t
|
||||||
|
torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
||||||
|
|
||||||
/// Gets the the dtype (data type) of a !torch.tensor.
|
/// Gets the the dtype (data type) of a !torch.tensor.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
|
torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
|
||||||
|
|
||||||
|
/// Gets the !torch.tensor typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.vtensor type.
|
// torch.vtensor type.
|
||||||
|
@ -259,11 +305,15 @@ MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);
|
||||||
|
|
||||||
/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
|
/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
|
||||||
/// indicates an unrefined/unknown size dimension.
|
/// indicates an unrefined/unknown size dimension.
|
||||||
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
MLIR_CAPI_EXPORTED int64_t
|
||||||
|
torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
||||||
|
|
||||||
/// Gets the the dtype (data type) of a !torch.vtensor.
|
/// Gets the the dtype (data type) of a !torch.vtensor.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
|
||||||
|
|
||||||
|
/// Gets the !torch.vtensor typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.none type.
|
// !torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -274,6 +324,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t);
|
||||||
/// Gets the !torch.none type.
|
/// Gets the !torch.none type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.none typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.str type.
|
// !torch.str type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -284,6 +337,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t);
|
||||||
/// Gets the !torch.str type.
|
/// Gets the !torch.str type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.str typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.any type.
|
// !torch.any type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -294,6 +350,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t);
|
||||||
/// Gets the !torch.str type.
|
/// Gets the !torch.str type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.any typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.number type.
|
// !torch.number type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -304,6 +363,9 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t);
|
||||||
/// Gets the !torch.number type.
|
/// Gets the !torch.number type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context);
|
||||||
|
|
||||||
|
/// Gets the !torch.number typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID();
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.dict type.
|
// !torch.dict type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -324,6 +386,9 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);
|
||||||
/// Gets the value type of a !torch.dict<key, value> type.
|
/// Gets the value type of a !torch.dict<key, value> type.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
|
||||||
|
|
||||||
|
/// Gets the !torch.dict typeid.
|
||||||
|
MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID();
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -34,6 +34,10 @@ MlirType torchMlirTorchNnModuleTypeGet(MlirContext context,
|
||||||
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchNnModuleTypeGetTypeID() {
|
||||||
|
return wrap(Torch::NnModuleType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.optional type.
|
// torch.optional type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -51,6 +55,10 @@ MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
|
||||||
return wrap(type.getContainedType());
|
return wrap(type.getContainedType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchOptionalTypeGetTypeID() {
|
||||||
|
return wrap(Torch::OptionalType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tuple<T1, T2, T3> type.
|
// torch.tuple<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -63,9 +71,8 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
||||||
intptr_t numContainedTypes,
|
intptr_t numContainedTypes,
|
||||||
MlirType const *containedTypes) {
|
MlirType const *containedTypes) {
|
||||||
return wrap(Torch::TupleType::get(
|
return wrap(Torch::TupleType::get(
|
||||||
unwrap(context),
|
unwrap(context), llvm::to_vector<6>(llvm::map_range(
|
||||||
llvm::to_vector<6>(
|
llvm::ArrayRef(containedTypes, numContainedTypes),
|
||||||
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
|
|
||||||
[](MlirType t) { return unwrap(t); }))));
|
[](MlirType t) { return unwrap(t); }))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -79,6 +86,10 @@ MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
|
||||||
return wrap(type.getContainedTypes()[pos]);
|
return wrap(type.getContainedTypes()[pos]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchTupleTypeGetTypeID() {
|
||||||
|
return wrap(Torch::TupleType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.union<T1, T2, T3> type.
|
// torch.union<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -91,9 +102,8 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
||||||
intptr_t numContainedTypes,
|
intptr_t numContainedTypes,
|
||||||
MlirType const *containedTypes) {
|
MlirType const *containedTypes) {
|
||||||
return wrap(Torch::UnionType::get(
|
return wrap(Torch::UnionType::get(
|
||||||
unwrap(context),
|
unwrap(context), llvm::to_vector<6>(llvm::map_range(
|
||||||
llvm::to_vector<6>(
|
llvm::ArrayRef(containedTypes, numContainedTypes),
|
||||||
llvm::map_range(llvm::ArrayRef(containedTypes, numContainedTypes),
|
|
||||||
[](MlirType t) { return unwrap(t); }))));
|
[](MlirType t) { return unwrap(t); }))));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -107,6 +117,10 @@ MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
|
||||||
return wrap(type.getContainedTypes()[pos]);
|
return wrap(type.getContainedTypes()[pos]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchUnionTypeGetTypeID() {
|
||||||
|
return wrap(Torch::UnionType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.list<T> type.
|
// torch.list<T> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -123,6 +137,10 @@ MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
|
||||||
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
|
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchListTypeGetTypeID() {
|
||||||
|
return wrap(Torch::ListType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Device type.
|
// torch.Device type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -135,6 +153,10 @@ MlirType torchMlirTorchDeviceTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::DeviceType::get(unwrap(context)));
|
return wrap(Torch::DeviceType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchDeviceTypeGetTypeID() {
|
||||||
|
return wrap(Torch::DeviceType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Generator type.
|
// torch.Generator type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -147,6 +169,10 @@ MlirType torchMlirTorchGeneratorTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::GeneratorType::get(unwrap(context)));
|
return wrap(Torch::GeneratorType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchGeneratorTypeGetTypeID() {
|
||||||
|
return wrap(Torch::GeneratorType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.bool type.
|
// torch.bool type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -159,6 +185,10 @@ MlirType torchMlirTorchBoolTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::BoolType::get(unwrap(context)));
|
return wrap(Torch::BoolType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchBoolTypeGetTypeID() {
|
||||||
|
return wrap(Torch::BoolType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.int type.
|
// torch.int type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -171,6 +201,10 @@ MlirType torchMlirTorchIntTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::IntType::get(unwrap(context)));
|
return wrap(Torch::IntType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchIntTypeGetTypeID() {
|
||||||
|
return wrap(Torch::IntType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.float type.
|
// torch.float type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -183,6 +217,10 @@ MlirType torchMlirTorchFloatTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::FloatType::get(unwrap(context)));
|
return wrap(Torch::FloatType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchFloatTypeGetTypeID() {
|
||||||
|
return wrap(Torch::FloatType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.LinearParams type.
|
// torch.LinearParams type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -195,6 +233,10 @@ MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::LinearParamsType::get(unwrap(context)));
|
return wrap(Torch::LinearParamsType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID() {
|
||||||
|
return wrap(Torch::LinearParamsType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.qint8 type.
|
// torch.qint8 type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -207,6 +249,10 @@ MlirType torchMlirTorchQInt8TypeGet(MlirContext context) {
|
||||||
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchQInt8TypeGetTypeID() {
|
||||||
|
return wrap(Torch::QInt8Type::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.quint8 type.
|
// torch.quint8 type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -219,6 +265,10 @@ MlirType torchMlirTorchQUInt8TypeGet(MlirContext context) {
|
||||||
return wrap(Torch::QUInt8Type::get(unwrap(context)));
|
return wrap(Torch::QUInt8Type::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchQUInt8TypeGetTypeID() {
|
||||||
|
return wrap(Torch::QUInt8Type::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tensor type.
|
// torch.tensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -282,6 +332,10 @@ MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
|
||||||
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
|
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID() {
|
||||||
|
return wrap(Torch::NonValueTensorType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.vtensor type.
|
// torch.vtensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -345,6 +399,10 @@ MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
|
||||||
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
|
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchValueTensorTypeGetTypeID() {
|
||||||
|
return wrap(Torch::ValueTensorType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.none type.
|
// torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -357,6 +415,10 @@ MlirType torchMlirTorchNoneTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::NoneType::get(unwrap(context)));
|
return wrap(Torch::NoneType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchNoneTypeGetTypeID() {
|
||||||
|
return wrap(Torch::NoneType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.str type.
|
// torch.str type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -369,6 +431,10 @@ MlirType torchMlirTorchStringTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::StringType::get(unwrap(context)));
|
return wrap(Torch::StringType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchStringTypeGetTypeID() {
|
||||||
|
return wrap(Torch::StringType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.any type.
|
// torch.any type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -381,6 +447,10 @@ MlirType torchMlirTorchAnyTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::AnyType::get(unwrap(context)));
|
return wrap(Torch::AnyType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchAnyTypeGetTypeID() {
|
||||||
|
return wrap(Torch::AnyType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.number type.
|
// torch.number type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -393,6 +463,10 @@ MlirType torchMlirTorchNumberTypeGet(MlirContext context) {
|
||||||
return wrap(Torch::NumberType::get(unwrap(context)));
|
return wrap(Torch::NumberType::get(unwrap(context)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchNumberTypeGetTypeID() {
|
||||||
|
return wrap(Torch::NumberType::getTypeID());
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Dict type.
|
// torch.Dict type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -421,3 +495,7 @@ MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
|
||||||
auto type = unwrap(t).cast<Torch::DictType>();
|
auto type = unwrap(t).cast<Torch::DictType>();
|
||||||
return wrap(type.getValueType());
|
return wrap(type.getValueType());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MlirTypeID torchMlirTorchDictTypeGetTypeID() {
|
||||||
|
return wrap(Torch::DictType::getTypeID());
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue