mirror of https://github.com/llvm/torch-mlir
Revert "Expose metadata of torch-mlir types (plus verify DictType key and value types). (#1785)"
This reverts commit 8696752eb6
.
revert-1785-expose_types
parent
2587b3f583
commit
745a616343
|
@ -190,7 +190,6 @@ add_custom_target(check-torch-mlir-all)
|
||||||
add_dependencies(check-torch-mlir-all
|
add_dependencies(check-torch-mlir-all
|
||||||
check-torch-mlir
|
check-torch-mlir
|
||||||
check-torch-mlir-dialects
|
check-torch-mlir-dialects
|
||||||
check-torch-mlir-capi
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
if(MLIR_ENABLE_BINDINGS_PYTHON)
|
||||||
|
|
|
@ -40,10 +40,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchOptionalTypeGet(MlirType containedType);
|
torchMlirTorchOptionalTypeGet(MlirType containedType);
|
||||||
|
|
||||||
/// Gets the subtype T of !torch.optional<T> type.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
|
||||||
torchMlirTorchOptionalTypeGetContained(MlirType containedType);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tuple<T1, T2, T3> type.
|
// torch.tuple<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -56,12 +52,6 @@ MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||||
MlirType const *containedTypes);
|
MlirType const *containedTypes);
|
||||||
|
|
||||||
/// Returns the number of types contained in a !torch.tuple type.
|
|
||||||
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);
|
|
||||||
|
|
||||||
/// Returns the pos-th type in the !torch.tuple type.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.union<T1, T2, T3> type.
|
// torch.union<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -74,12 +64,6 @@ MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
|
torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
|
||||||
MlirType const *containedTypes);
|
MlirType const *containedTypes);
|
||||||
|
|
||||||
/// Returns the number of types contained in a !torch.union type.
|
|
||||||
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);
|
|
||||||
|
|
||||||
/// Returns the pos-th type in the !torch.union type.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.list<T> type.
|
// torch.list<T> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -90,9 +74,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
|
||||||
/// Gets the !torch.list<T> type with contained T.
|
/// Gets the !torch.list<T> type with contained T.
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
|
||||||
|
|
||||||
/// Gets contained T in a !torch.list<T> type.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Device type.
|
// torch.Device type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -201,22 +182,6 @@ torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);
|
torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);
|
||||||
|
|
||||||
/// Gets the the rank (number of dimensions) of a !torch.tensor
|
|
||||||
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t);
|
|
||||||
|
|
||||||
/// Return true if this type has a list of sizes.
|
|
||||||
MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t);
|
|
||||||
|
|
||||||
/// Return true if this type has a dtype.
|
|
||||||
MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);
|
|
||||||
|
|
||||||
/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
|
|
||||||
/// indicates an unrefined/unknown size dimension.
|
|
||||||
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
|
||||||
|
|
||||||
/// Gets the the dtype (data type) of a !torch.tensor.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.vtensor type.
|
// torch.vtensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -243,22 +208,6 @@ torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||||
MLIR_CAPI_EXPORTED MlirType
|
MLIR_CAPI_EXPORTED MlirType
|
||||||
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
|
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
|
||||||
|
|
||||||
/// Gets the the rank (number of dimensions) of a !torch.vtensor
|
|
||||||
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t);
|
|
||||||
|
|
||||||
/// Return true if this type has a list of sizes.
|
|
||||||
MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasSizes(MlirType t);
|
|
||||||
|
|
||||||
/// Return true if this type has a dtype.
|
|
||||||
MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);
|
|
||||||
|
|
||||||
/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
|
|
||||||
/// indicates an unrefined/unknown size dimension.
|
|
||||||
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
|
|
||||||
|
|
||||||
/// Gets the the dtype (data type) of a !torch.vtensor.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// !torch.none type.
|
// !torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -310,15 +259,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
|
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
|
||||||
MlirType valueType);
|
MlirType valueType);
|
||||||
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetChecked(
|
|
||||||
MlirContext context, MlirType keyType, MlirType valueType);
|
|
||||||
|
|
||||||
/// Gets the key type of a !torch.dict<key, value> type.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);
|
|
||||||
|
|
||||||
/// Gets the value type of a !torch.dict<key, value> type.
|
|
||||||
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -347,7 +347,6 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
|
||||||
Torch Dict type with key and value type.
|
Torch Dict type with key and value type.
|
||||||
}];
|
}];
|
||||||
let hasCustomAssemblyFormat = 1;
|
let hasCustomAssemblyFormat = 1;
|
||||||
let genVerifyDecl = 1;
|
|
||||||
let builders = [
|
let builders = [
|
||||||
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
|
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
|
||||||
"::mlir::Type":$valueType), [{
|
"::mlir::Type":$valueType), [{
|
||||||
|
@ -412,6 +411,7 @@ def AnyTorchDictKeyType : AnyTypeOf<[
|
||||||
Torch_BoolType,
|
Torch_BoolType,
|
||||||
Torch_FloatType,
|
Torch_FloatType,
|
||||||
Torch_StringType,
|
Torch_StringType,
|
||||||
|
Torch_FloatType,
|
||||||
AnyTorchTensorType,
|
AnyTorchTensorType,
|
||||||
], "Allowed dict key types">;
|
], "Allowed dict key types">;
|
||||||
|
|
||||||
|
|
|
@ -42,11 +42,6 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
|
||||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
|
|
||||||
auto type = unwrap(t).cast<Torch::OptionalType>();
|
|
||||||
return wrap(type.getContainedType());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.tuple<T1, T2, T3> type.
|
// torch.tuple<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -65,16 +60,6 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
|
||||||
[](MlirType t) { return unwrap(t); }))));
|
[](MlirType t) { return unwrap(t); }))));
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
|
|
||||||
auto type = unwrap(t).cast<Torch::TupleType>();
|
|
||||||
return type.getContainedTypes().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
|
|
||||||
auto type = unwrap(t).cast<Torch::TupleType>();
|
|
||||||
return wrap(type.getContainedTypes()[pos]);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.union<T1, T2, T3> type.
|
// torch.union<T1, T2, T3> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -93,16 +78,6 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
|
||||||
[](MlirType t) { return unwrap(t); }))));
|
[](MlirType t) { return unwrap(t); }))));
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
|
|
||||||
auto type = unwrap(t).cast<Torch::UnionType>();
|
|
||||||
return type.getContainedTypes().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
|
|
||||||
auto type = unwrap(t).cast<Torch::UnionType>();
|
|
||||||
return wrap(type.getContainedTypes()[pos]);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.list<T> type.
|
// torch.list<T> type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -115,10 +90,6 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
|
||||||
return wrap(Torch::ListType::get(unwrap(containedType)));
|
return wrap(Torch::ListType::get(unwrap(containedType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
|
|
||||||
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.Device type.
|
// torch.Device type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -249,35 +220,6 @@ MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||||
attrTensorType.getElementType()));
|
attrTensorType.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
|
|
||||||
return unwrap(t).cast<Torch::NonValueTensorType>().getSizes().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
|
|
||||||
return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
|
|
||||||
return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
|
|
||||||
auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
|
|
||||||
bool hasSizes = tensorType.hasSizes();
|
|
||||||
if (!hasSizes)
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
auto sizes_ = tensorType.getSizes();
|
|
||||||
for (const auto &s : llvm::enumerate(sizes_)) {
|
|
||||||
sizes[s.index()] = s.value();
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
|
|
||||||
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.vtensor type.
|
// torch.vtensor type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -312,35 +254,6 @@ MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
|
||||||
attrTensorType.getElementType()));
|
attrTensorType.getElementType()));
|
||||||
}
|
}
|
||||||
|
|
||||||
int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
|
|
||||||
return unwrap(t).cast<Torch::ValueTensorType>().getSizes().size();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
|
|
||||||
return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
|
|
||||||
return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
|
|
||||||
}
|
|
||||||
|
|
||||||
int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
|
|
||||||
auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
|
|
||||||
bool hasSizes = tensorType.hasSizes();
|
|
||||||
if (!hasSizes)
|
|
||||||
return -1;
|
|
||||||
|
|
||||||
auto sizes_ = tensorType.getSizes();
|
|
||||||
for (const auto &s : llvm::enumerate(sizes_)) {
|
|
||||||
sizes[s.index()] = s.value();
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
|
|
||||||
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// torch.none type.
|
// torch.none type.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -400,20 +313,3 @@ bool torchMlirTypeIsATorchDict(MlirType t) {
|
||||||
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
|
||||||
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
|
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
|
||||||
}
|
}
|
||||||
|
|
||||||
MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
|
|
||||||
MlirType valueType) {
|
|
||||||
auto unknownLoc = unwrap(mlirLocationUnknownGet(context));
|
|
||||||
return wrap(Torch::DictType::getChecked(unknownLoc, unwrap(context),
|
|
||||||
unwrap(keyType), unwrap(valueType)));
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
|
|
||||||
auto type = unwrap(t).cast<Torch::DictType>();
|
|
||||||
return wrap(type.getKeyType());
|
|
||||||
}
|
|
||||||
|
|
||||||
MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
|
|
||||||
auto type = unwrap(t).cast<Torch::DictType>();
|
|
||||||
return wrap(type.getValueType());
|
|
||||||
}
|
|
||||||
|
|
|
@ -465,44 +465,3 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
|
||||||
|
|
||||||
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
|
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
|
||||||
}
|
}
|
||||||
|
|
||||||
////===----------------------------------------------------------------------===//
|
|
||||||
//// DictType
|
|
||||||
////===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType
|
|
||||||
// and AnyTorchType generate the exact same code (in TorchOps.cpp.inc).
|
|
||||||
// Unfortunately the generated implementations aren't visible/exposed ("static" linkage)
|
|
||||||
// and the predicates themselves can't be added/used in the specification of the parameters
|
|
||||||
// of the Torch_DictType.
|
|
||||||
static bool isAnyTorchDictKeyType(Type type) {
|
|
||||||
return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
|
|
||||||
type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||
|
|
||||||
type.isa<Torch::StringType>() || type.isa<Torch::BaseTensorType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool isAnyTorchType(Type type) {
|
|
||||||
return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
|
|
||||||
type.isa<Torch::BaseTensorType>() || type.isa<Torch::AnyType>() ||
|
|
||||||
type.isa<Torch::BoolType>() || type.isa<Torch::DictType>() ||
|
|
||||||
type.isa<Torch::DeviceType>() || type.isa<Torch::GeneratorType>() ||
|
|
||||||
type.isa<Torch::ListType>() || type.isa<Torch::LinearParamsType>() ||
|
|
||||||
type.isa<Torch::NumberType>() || type.isa<Torch::NnModuleType>() ||
|
|
||||||
type.isa<Torch::NoneType>() || type.isa<Torch::OptionalType>() ||
|
|
||||||
type.isa<Torch::StringType>() || type.isa<Torch::TupleType>() ||
|
|
||||||
type.isa<Torch::UnionType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
LogicalResult
|
|
||||||
DictType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
|
|
||||||
Type keyType, Type valueType) {
|
|
||||||
if (!isAnyTorchDictKeyType(keyType)) {
|
|
||||||
emitError() << "invalid " << keyType << " for !torch.dict key type";
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
if (!isAnyTorchType(valueType)) {
|
|
||||||
emitError() << "invalid " << valueType << " for !torch.dict value type";
|
|
||||||
return failure();
|
|
||||||
}
|
|
||||||
return success();
|
|
||||||
}
|
|
|
@ -1,15 +0,0 @@
|
||||||
add_llvm_executable(torch-mlir-capi-torch-test torch.c)
|
|
||||||
llvm_update_compile_flags(torch-mlir-capi-torch-test)
|
|
||||||
target_link_libraries(
|
|
||||||
torch-mlir-capi-torch-test
|
|
||||||
PRIVATE
|
|
||||||
MLIRCAPIIR
|
|
||||||
MLIRCAPIRegisterEverything
|
|
||||||
TorchMLIRCAPI
|
|
||||||
)
|
|
||||||
|
|
||||||
add_lit_testsuite(check-torch-mlir-capi "Running the torch-mlir CAPI tests"
|
|
||||||
${CMAKE_CURRENT_BINARY_DIR}
|
|
||||||
DEPENDS torch-mlir-capi-torch-test
|
|
||||||
)
|
|
||||||
set_target_properties(check-torch-mlir-capi PROPERTIES FOLDER "Tests")
|
|
|
@ -1 +0,0 @@
|
||||||
config.suffixes.add('.c')
|
|
|
@ -1,186 +0,0 @@
|
||||||
//===- torch.c - Test of Torch dialect C API ------------------------------===//
|
|
||||||
//
|
|
||||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
|
||||||
// Exceptions.
|
|
||||||
// See https://llvm.org/LICENSE.txt for license information.
|
|
||||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
||||||
//
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
// RUN: torch-mlir-capi-torch-test 2>&1 | FileCheck %s
|
|
||||||
|
|
||||||
#include "mlir-c/BuiltinTypes.h"
|
|
||||||
#include "torch-mlir-c/Registration.h"
|
|
||||||
#include "torch-mlir-c/TorchTypes.h"
|
|
||||||
|
|
||||||
#include <inttypes.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
#include <stdlib.h>
|
|
||||||
|
|
||||||
static void printToStderr(MlirStringRef str, void *userData) {
|
|
||||||
(void)userData;
|
|
||||||
fwrite(str.data, 1, str.length, stderr);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes,
|
|
||||||
MlirType dType, const char *testName) {
|
|
||||||
#define DEFINE_CHECK(TTT) \
|
|
||||||
MlirType TTT##Type = \
|
|
||||||
torchMlirTorch##TTT##TypeGet(ctx, numSizes, sizes, dType); \
|
|
||||||
\
|
|
||||||
bool TTT##hasSizes = torchMlirTorch##TTT##TypeHasSizes(TTT##Type); \
|
|
||||||
fprintf(stderr, #TTT "Type %s hasSizes: %d\n", testName, TTT##hasSizes); \
|
|
||||||
bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \
|
|
||||||
fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \
|
|
||||||
if (TTT##hasSizes) { \
|
|
||||||
fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \
|
|
||||||
torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \
|
|
||||||
int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \
|
|
||||||
torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \
|
|
||||||
for (int i = 0; i < numSizes; ++i) { \
|
|
||||||
fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \
|
|
||||||
TTT##Sizes[i]); \
|
|
||||||
} \
|
|
||||||
} \
|
|
||||||
\
|
|
||||||
if (TTT##hasDtype) { \
|
|
||||||
MlirType TTT##Dtype = torchMlirTorch##TTT##TypeGetDtype(TTT##Type); \
|
|
||||||
fprintf(stderr, #TTT "Type %s dtype: ", testName); \
|
|
||||||
mlirTypePrint(TTT##Dtype, printToStderr, NULL); \
|
|
||||||
fprintf(stderr, "\n"); \
|
|
||||||
}
|
|
||||||
DEFINE_CHECK(NonValueTensor)
|
|
||||||
DEFINE_CHECK(ValueTensor)
|
|
||||||
#undef DEFINE_CHECK
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: testTypeMetaDataAccessors
|
|
||||||
static void testTypeMetaDataAccessors(MlirContext ctx) {
|
|
||||||
fprintf(stderr, "testTypeMetaDataAccessors\n");
|
|
||||||
|
|
||||||
MlirType i8 = mlirIntegerTypeGet(ctx, 8);
|
|
||||||
MlirType optionalI8 = torchMlirTorchOptionalTypeGet(i8);
|
|
||||||
|
|
||||||
fprintf(stderr, "optionalI8 isa TorchOptional: %d\n",
|
|
||||||
torchMlirTypeIsATorchOptional(optionalI8));
|
|
||||||
// CHECK: optionalI8 isa TorchOptional: 1
|
|
||||||
|
|
||||||
MlirType containedType = torchMlirTorchOptionalTypeGetContained(optionalI8);
|
|
||||||
fprintf(stderr, "optionalI8 containedType: ");
|
|
||||||
mlirTypePrint(containedType, printToStderr, NULL);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
// CHECK: optionalI8 containedType: i8
|
|
||||||
|
|
||||||
MlirType f16 = mlirF16TypeGet(ctx);
|
|
||||||
MlirType f32 = mlirF32TypeGet(ctx);
|
|
||||||
MlirType _tupleI8[3] = {i8, f16, f32};
|
|
||||||
#define DEFINE_CHECK(TTT) \
|
|
||||||
MlirType TTT##I8 = torchMlirTorch##TTT##TypeGet(ctx, 3, _tupleI8); \
|
|
||||||
\
|
|
||||||
fprintf(stderr, #TTT "I8 isa " #TTT ": %d\n", \
|
|
||||||
torchMlirTypeIsATorch##TTT(TTT##I8)); \
|
|
||||||
\
|
|
||||||
fprintf(stderr, #TTT "I8 NumTypes: %zu\n", \
|
|
||||||
torchMlirTorch##TTT##TypeGetNumTypes(TTT##I8)); \
|
|
||||||
\
|
|
||||||
for (int i = 0; i < 3; ++i) { \
|
|
||||||
fprintf(stderr, #TTT "I8 pos %d type: ", i); \
|
|
||||||
mlirTypePrint(torchMlirTorch##TTT##TypeGetType(TTT##I8, i), printToStderr, \
|
|
||||||
NULL); \
|
|
||||||
fprintf(stderr, "\n"); \
|
|
||||||
}
|
|
||||||
DEFINE_CHECK(Tuple)
|
|
||||||
DEFINE_CHECK(Union)
|
|
||||||
#undef DEFINE_CHECK
|
|
||||||
// CHECK: TupleI8 isa Tuple: 1
|
|
||||||
// CHECK: TupleI8 NumTypes: 3
|
|
||||||
// CHECK: TupleI8 pos 0 type: i8
|
|
||||||
// CHECK: TupleI8 pos 1 type: f16
|
|
||||||
// CHECK: TupleI8 pos 2 type: f32
|
|
||||||
// CHECK: UnionI8 isa Union: 1
|
|
||||||
// CHECK: UnionI8 NumTypes: 3
|
|
||||||
// CHECK: UnionI8 pos 0 type: i8
|
|
||||||
// CHECK: UnionI8 pos 1 type: f16
|
|
||||||
// CHECK: UnionI8 pos 2 type: f32
|
|
||||||
|
|
||||||
int64_t sizes[3] = {1, 2, 3};
|
|
||||||
testTensor(ctx, 3, sizes, f32, "has-sizes-dtype");
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype hasSizes: 1
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype hasDtype: 1
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype rank: 3
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype pos 0 size: 1
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype pos 1 size: 2
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype pos 2 size: 3
|
|
||||||
// CHECK: NonValueTensorType has-sizes-dtype dtype: f32
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype hasSizes: 1
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype hasDtype: 1
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype rank: 3
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype pos 0 size: 1
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype pos 1 size: 2
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype pos 2 size: 3
|
|
||||||
// CHECK: ValueTensorType has-sizes-dtype dtype: f32
|
|
||||||
|
|
||||||
MlirType nullType = {NULL};
|
|
||||||
testTensor(ctx, 3, sizes, nullType, "has-sizes-no-dtype");
|
|
||||||
// CHECK: NonValueTensorType has-sizes-no-dtype hasSizes: 1
|
|
||||||
// CHECK: NonValueTensorType has-sizes-no-dtype hasDtype: 0
|
|
||||||
// CHECK: NonValueTensorType has-sizes-no-dtype rank: 3
|
|
||||||
// CHECK: NonValueTensorType has-sizes-no-dtype pos 0 size: 1
|
|
||||||
// CHECK: NonValueTensorType has-sizes-no-dtype pos 1 size: 2
|
|
||||||
// CHECK: NonValueTensorType has-sizes-no-dtype pos 2 size: 3
|
|
||||||
// CHECK: ValueTensorType has-sizes-no-dtype hasSizes: 1
|
|
||||||
// CHECK: ValueTensorType has-sizes-no-dtype hasDtype: 0
|
|
||||||
// CHECK: ValueTensorType has-sizes-no-dtype rank: 3
|
|
||||||
// CHECK: ValueTensorType has-sizes-no-dtype pos 0 size: 1
|
|
||||||
// CHECK: ValueTensorType has-sizes-no-dtype pos 1 size: 2
|
|
||||||
// CHECK: ValueTensorType has-sizes-no-dtype pos 2 size: 3
|
|
||||||
testTensor(ctx, -1, sizes, f32, "no-sizes-has-dtype");
|
|
||||||
// CHECK: NonValueTensorType no-sizes-has-dtype hasSizes: 0
|
|
||||||
// CHECK: NonValueTensorType no-sizes-has-dtype hasDtype: 1
|
|
||||||
// CHECK: NonValueTensorType no-sizes-has-dtype dtype: f32
|
|
||||||
// CHECK: ValueTensorType no-sizes-has-dtype hasSizes: 0
|
|
||||||
// CHECK: ValueTensorType no-sizes-has-dtype hasDtype: 1
|
|
||||||
// CHECK: ValueTensorType no-sizes-has-dtype dtype: f32
|
|
||||||
|
|
||||||
MlirType floatType = torchMlirTorchFloatTypeGet(ctx);
|
|
||||||
torchMlirTorchDictTypeGetChecked(ctx, f16, floatType);
|
|
||||||
// CHECK: error: invalid 'f16' for !torch.dict key type
|
|
||||||
torchMlirTorchDictTypeGetChecked(ctx, i8, floatType);
|
|
||||||
// CHECK: error: invalid 'i8' for !torch.dict key type
|
|
||||||
torchMlirTorchDictTypeGetChecked(ctx, floatType, f16);
|
|
||||||
// CHECK: error: invalid 'f16' for !torch.dict value type
|
|
||||||
torchMlirTorchDictTypeGetChecked(ctx, floatType, i8);
|
|
||||||
// CHECK: error: invalid 'i8' for !torch.dict value type
|
|
||||||
|
|
||||||
MlirType strType = torchMlirTorchStringTypeGet(ctx);
|
|
||||||
|
|
||||||
MlirType dictType1 = torchMlirTorchDictTypeGet(strType, floatType);
|
|
||||||
|
|
||||||
fprintf(stderr, "dict keyType: ");
|
|
||||||
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, NULL);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
// CHECK: dict keyType: !torch.str
|
|
||||||
fprintf(stderr, "dict valueType: ");
|
|
||||||
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, NULL);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
// CHECK: dict valueType: !torch.float
|
|
||||||
|
|
||||||
MlirType dictType2 = torchMlirTorchDictTypeGet(floatType, strType);
|
|
||||||
|
|
||||||
fprintf(stderr, "dict keyType: ");
|
|
||||||
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, NULL);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
// CHECK: dict keyType: !torch.float
|
|
||||||
fprintf(stderr, "dict valueType: ");
|
|
||||||
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, NULL);
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
// CHECK: dict valueType: !torch.str
|
|
||||||
}
|
|
||||||
|
|
||||||
int main(void) {
|
|
||||||
MlirContext ctx = mlirContextCreate();
|
|
||||||
torchMlirRegisterAllDialects(ctx);
|
|
||||||
testTypeMetaDataAccessors(ctx);
|
|
||||||
mlirContextDestroy(ctx);
|
|
||||||
return EXIT_SUCCESS;
|
|
||||||
}
|
|
|
@ -23,5 +23,3 @@ add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
|
||||||
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
|
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
|
||||||
|
|
||||||
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
|
||||||
|
|
||||||
add_subdirectory(CAPI)
|
|
||||||
|
|
|
@ -56,8 +56,6 @@ config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
|
||||||
|
|
||||||
# Tweak the PATH to include the tools dir.
|
# Tweak the PATH to include the tools dir.
|
||||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||||
# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree.
|
|
||||||
llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True)
|
|
||||||
|
|
||||||
# On Windows the path to python could contains spaces in which case it needs to
|
# On Windows the path to python could contains spaces in which case it needs to
|
||||||
# be provided in quotes. This is the equivalent of how %python is setup in
|
# be provided in quotes. This is the equivalent of how %python is setup in
|
||||||
|
|
|
@ -9,7 +9,6 @@ config.host_os = "@HOST_OS@"
|
||||||
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
|
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
|
||||||
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
|
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
|
||||||
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
|
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
|
||||||
config.llvm_build_dir = "@CMAKE_BINARY_DIR@"
|
|
||||||
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
|
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
|
||||||
config.llvm_shlib_dir = "@SHLIBDIR@"
|
config.llvm_shlib_dir = "@SHLIBDIR@"
|
||||||
config.llvm_shlib_ext = "@SHLIBEXT@"
|
config.llvm_shlib_ext = "@SHLIBEXT@"
|
||||||
|
|
Loading…
Reference in New Issue