mirror of https://github.com/llvm/torch-mlir
Revert "Expose metadata of torch-mlir types (plus verify DictType key and value types). (#1785)"
This reverts commit 8696752eb6.
Branch: revert-1785-expose_types
Parent: 2587b3f583
Commit: 745a616343
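For context, the reverted change had exposed metadata accessors for torch-mlir types through the C API; the hunks below remove those declarations, their implementations, and the CAPI test that exercised them. A minimal sketch (not upstream code) of the kind of query the removed tuple accessors enabled, with context setup mirroring the removed test/CAPI/torch.c:

// Sketch only: exercises the tuple metadata accessors declared in the
// removed parts of torch-mlir-c/TorchTypes.h.
#include "mlir-c/BuiltinTypes.h"
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/TorchTypes.h"

static void queryTupleMetadata(void) {
  MlirContext ctx = mlirContextCreate();
  torchMlirRegisterAllDialects(ctx);

  MlirType elems[2] = {mlirIntegerTypeGet(ctx, 8), mlirF32TypeGet(ctx)};
  MlirType tuple = torchMlirTorchTupleTypeGet(ctx, 2, elems);

  size_t n = torchMlirTorchTupleTypeGetNumTypes(tuple);      // 2
  MlirType first = torchMlirTorchTupleTypeGetType(tuple, 0); // i8

  (void)n;
  (void)first;
  mlirContextDestroy(ctx);
}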
@@ -190,7 +190,6 @@ add_custom_target(check-torch-mlir-all)
add_dependencies(check-torch-mlir-all
  check-torch-mlir
  check-torch-mlir-dialects
  check-torch-mlir-capi
  )

if(MLIR_ENABLE_BINDINGS_PYTHON)

@@ -40,10 +40,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGet(MlirType containedType);

/// Gets the subtype T of !torch.optional<T> type.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGetContained(MlirType containedType);

//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

@@ -56,12 +52,6 @@ MLIR_CAPI_EXPORTED MlirType
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
                           MlirType const *containedTypes);

/// Returns the number of types contained in a !torch.tuple type.
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);

/// Returns the pos-th type in the !torch.tuple type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos);

//===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

@@ -74,12 +64,6 @@ MLIR_CAPI_EXPORTED MlirType
torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
                           MlirType const *containedTypes);

/// Returns the number of types contained in a !torch.union type.
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);

/// Returns the pos-th type in the !torch.union type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos);

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//

@@ -90,9 +74,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
/// Gets the !torch.list<T> type with contained T.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);

/// Gets contained T in a !torch.list<T> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);

//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//

@@ -201,22 +182,6 @@ torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);

/// Gets the rank (number of dimensions) of a !torch.tensor.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t);

/// Return true if this type has a list of sizes.
MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t);

/// Return true if this type has a dtype.
MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);

/// Gets the sizes of the dimensions of a !torch.tensor; note -1 size
/// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);

/// Gets the dtype (data type) of a !torch.tensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);

//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//

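A minimal sketch of the intended query pattern for these tensor metadata accessors (assumptions: `t` is a !torch.tensor type from a context with the Torch dialect registered, as in the removed test; the helper name is hypothetical). The HasSizes/HasDtype checks guard the Get calls, since sizes and dtype may be unrefined:

// Sketch only: print shape/dtype metadata of a non-value tensor type.
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
#include "torch-mlir-c/TorchTypes.h"

static void printTensorMetadata(MlirType t) {
  if (torchMlirTorchNonValueTensorTypeHasSizes(t)) {
    int64_t rank = torchMlirTorchNonValueTensorTypeGetRank(t);
    int64_t *sizes = malloc(sizeof(int64_t) * rank);
    torchMlirTorchNonValueTensorTypeGetSizes(t, sizes);
    for (int64_t i = 0; i < rank; ++i)
      printf("dim %" PRId64 ": %" PRId64 "\n", i, sizes[i]); // -1 = unknown
    free(sizes);
  }
  if (torchMlirTorchNonValueTensorTypeHasDtype(t)) {
    MlirType dtype = torchMlirTorchNonValueTensorTypeGetDtype(t);
    (void)dtype; // e.g. print with mlirTypePrint
  }
}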
@@ -243,22 +208,6 @@ torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);

/// Gets the rank (number of dimensions) of a !torch.vtensor.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t);

/// Return true if this type has a list of sizes.
MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasSizes(MlirType t);

/// Return true if this type has a dtype.
MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);

/// Gets the sizes of the dimensions of a !torch.vtensor; note -1 size
/// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);

/// Gets the dtype (data type) of a !torch.vtensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);

//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//

@@ -310,15 +259,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
                                                      MlirType valueType);

MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetChecked(
    MlirContext context, MlirType keyType, MlirType valueType);

/// Gets the key type of a !torch.dict<key, value> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);

/// Gets the value type of a !torch.dict<key, value> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);

#ifdef __cplusplus
}
#endif

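A brief sketch of how these dict accessors were meant to be used, mirroring the removed torch.c test (the helper name is hypothetical; `ctx` is assumed to have the Torch dialect registered via torchMlirRegisterAllDialects):

// Sketch only: build !torch.dict<str, float> and read its key/value types back.
#include "torch-mlir-c/TorchTypes.h"

static void inspectDictType(MlirContext ctx) {
  MlirType key = torchMlirTorchStringTypeGet(ctx);
  MlirType value = torchMlirTorchFloatTypeGet(ctx);
  MlirType dict = torchMlirTorchDictTypeGet(key, value);

  MlirType k = torchMlirTorchDictTypeGetKeyType(dict);   // !torch.str
  MlirType v = torchMlirTorchDictTypeGetValueType(dict); // !torch.float
  (void)k;
  (void)v;
}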
@@ -347,7 +347,6 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
    Torch Dict type with key and value type.
  }];
  let hasCustomAssemblyFormat = 1;
  let genVerifyDecl = 1;
  let builders = [
    TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
                                        "::mlir::Type":$valueType), [{

@@ -412,6 +411,7 @@ def AnyTorchDictKeyType : AnyTypeOf<[
    Torch_BoolType,
    Torch_FloatType,
    Torch_StringType,
    Torch_FloatType,
    AnyTorchTensorType,
], "Allowed dict key types">;

@@ -42,11 +42,6 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
  return wrap(Torch::OptionalType::get(unwrap(containedType)));
}

MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
  auto type = unwrap(t).cast<Torch::OptionalType>();
  return wrap(type.getContainedType());
}

//===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

@@ -65,16 +60,6 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
                     [](MlirType t) { return unwrap(t); }))));
}

size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
  auto type = unwrap(t).cast<Torch::TupleType>();
  return type.getContainedTypes().size();
}

MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
  auto type = unwrap(t).cast<Torch::TupleType>();
  return wrap(type.getContainedTypes()[pos]);
}

//===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===//

@@ -93,16 +78,6 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
                     [](MlirType t) { return unwrap(t); }))));
}

size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
  auto type = unwrap(t).cast<Torch::UnionType>();
  return type.getContainedTypes().size();
}

MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
  auto type = unwrap(t).cast<Torch::UnionType>();
  return wrap(type.getContainedTypes()[pos]);
}

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//

@@ -115,10 +90,6 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
  return wrap(Torch::ListType::get(unwrap(containedType)));
}

MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
  return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
}

//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//

@@ -249,35 +220,6 @@ MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
                                              attrTensorType.getElementType()));
}

int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
  return unwrap(t).cast<Torch::NonValueTensorType>().getSizes().size();
}

bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
  return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
}

bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
  return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
}

int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
  auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
  bool hasSizes = tensorType.hasSizes();
  if (!hasSizes)
    return -1;

  auto sizes_ = tensorType.getSizes();
  for (const auto &s : llvm::enumerate(sizes_)) {
    sizes[s.index()] = s.value();
  }
  return 0;
}

MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
  return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
}

//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//

@@ -312,35 +254,6 @@ MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
                                           attrTensorType.getElementType()));
}

int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
  return unwrap(t).cast<Torch::ValueTensorType>().getSizes().size();
}

bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
  return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
}

bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
  return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
}

int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
  auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
  bool hasSizes = tensorType.hasSizes();
  if (!hasSizes)
    return -1;

  auto sizes_ = tensorType.getSizes();
  for (const auto &s : llvm::enumerate(sizes_)) {
    sizes[s.index()] = s.value();
  }
  return 0;
}

MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
  return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
}

//===----------------------------------------------------------------------===//
// torch.none type.
//===----------------------------------------------------------------------===//

@@ -400,20 +313,3 @@ bool torchMlirTypeIsATorchDict(MlirType t) {
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
  return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
}

MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
                                          MlirType valueType) {
  auto unknownLoc = unwrap(mlirLocationUnknownGet(context));
  return wrap(Torch::DictType::getChecked(unknownLoc, unwrap(context),
                                          unwrap(keyType), unwrap(valueType)));
}

MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
  auto type = unwrap(t).cast<Torch::DictType>();
  return wrap(type.getKeyType());
}

MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
  auto type = unwrap(t).cast<Torch::DictType>();
  return wrap(type.getValueType());
}

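Because the checked constructor above runs DictType's verifier at an unknown location, an invalid key or value type is reported as a diagnostic on the context (e.g. "invalid 'f16' for !torch.dict key type", which the removed torch.c test checks for). A hedged C sketch of that behavior; the assumption that a null type is returned on failure is mine, not a documented guarantee:

// Sketch only: attempt checked construction with an illegal key type.
#include "mlir-c/BuiltinTypes.h"
#include "torch-mlir-c/TorchTypes.h"

static void tryCheckedDict(MlirContext ctx) {
  MlirType f16 = mlirF16TypeGet(ctx); // builtin f16 is not a legal dict key
  MlirType value = torchMlirTorchFloatTypeGet(ctx);
  MlirType bad = torchMlirTorchDictTypeGetChecked(ctx, f16, value);
  if (mlirTypeIsNull(bad)) {
    // construction rejected by the verifier; diagnostic already emitted
  }
}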
@@ -465,44 +465,3 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {

  return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
}

////===----------------------------------------------------------------------===//
//// DictType
////===----------------------------------------------------------------------===//

// TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType
// and AnyTorchType generate the exact same code (in TorchOps.cpp.inc).
// Unfortunately the generated implementations aren't visible/exposed ("static" linkage)
// and the predicates themselves can't be added/used in the specification of the parameters
// of the Torch_DictType.
static bool isAnyTorchDictKeyType(Type type) {
  return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
         type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||
         type.isa<Torch::StringType>() || type.isa<Torch::BaseTensorType>();
}

static bool isAnyTorchType(Type type) {
  return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
         type.isa<Torch::BaseTensorType>() || type.isa<Torch::AnyType>() ||
         type.isa<Torch::BoolType>() || type.isa<Torch::DictType>() ||
         type.isa<Torch::DeviceType>() || type.isa<Torch::GeneratorType>() ||
         type.isa<Torch::ListType>() || type.isa<Torch::LinearParamsType>() ||
         type.isa<Torch::NumberType>() || type.isa<Torch::NnModuleType>() ||
         type.isa<Torch::NoneType>() || type.isa<Torch::OptionalType>() ||
         type.isa<Torch::StringType>() || type.isa<Torch::TupleType>() ||
         type.isa<Torch::UnionType>();
}

LogicalResult
DictType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
                 Type keyType, Type valueType) {
  if (!isAnyTorchDictKeyType(keyType)) {
    emitError() << "invalid " << keyType << " for !torch.dict key type";
    return failure();
  }
  if (!isAnyTorchType(valueType)) {
    emitError() << "invalid " << valueType << " for !torch.dict value type";
    return failure();
  }
  return success();
}

@@ -1,15 +0,0 @@
add_llvm_executable(torch-mlir-capi-torch-test torch.c)
llvm_update_compile_flags(torch-mlir-capi-torch-test)
target_link_libraries(
  torch-mlir-capi-torch-test
  PRIVATE
  MLIRCAPIIR
  MLIRCAPIRegisterEverything
  TorchMLIRCAPI
)

add_lit_testsuite(check-torch-mlir-capi "Running the torch-mlir CAPI tests"
  ${CMAKE_CURRENT_BINARY_DIR}
  DEPENDS torch-mlir-capi-torch-test
)
set_target_properties(check-torch-mlir-capi PROPERTIES FOLDER "Tests")

@@ -1 +0,0 @@
config.suffixes.add('.c')

@@ -1,186 +0,0 @@
//===- torch.c - Test of Torch dialect C API ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// RUN: torch-mlir-capi-torch-test 2>&1 | FileCheck %s

#include "mlir-c/BuiltinTypes.h"
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/TorchTypes.h"

#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>

static void printToStderr(MlirStringRef str, void *userData) {
  (void)userData;
  fwrite(str.data, 1, str.length, stderr);
}

static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes,
                       MlirType dType, const char *testName) {
#define DEFINE_CHECK(TTT) \
  MlirType TTT##Type = \
      torchMlirTorch##TTT##TypeGet(ctx, numSizes, sizes, dType); \
  \
  bool TTT##hasSizes = torchMlirTorch##TTT##TypeHasSizes(TTT##Type); \
  fprintf(stderr, #TTT "Type %s hasSizes: %d\n", testName, TTT##hasSizes); \
  bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \
  fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \
  if (TTT##hasSizes) { \
    fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \
            torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \
    int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \
    torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \
    for (int i = 0; i < numSizes; ++i) { \
      fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \
              TTT##Sizes[i]); \
    } \
  } \
  \
  if (TTT##hasDtype) { \
    MlirType TTT##Dtype = torchMlirTorch##TTT##TypeGetDtype(TTT##Type); \
    fprintf(stderr, #TTT "Type %s dtype: ", testName); \
    mlirTypePrint(TTT##Dtype, printToStderr, NULL); \
    fprintf(stderr, "\n"); \
  }
  DEFINE_CHECK(NonValueTensor)
  DEFINE_CHECK(ValueTensor)
#undef DEFINE_CHECK
}

// CHECK-LABEL: testTypeMetaDataAccessors
static void testTypeMetaDataAccessors(MlirContext ctx) {
  fprintf(stderr, "testTypeMetaDataAccessors\n");

  MlirType i8 = mlirIntegerTypeGet(ctx, 8);
  MlirType optionalI8 = torchMlirTorchOptionalTypeGet(i8);

  fprintf(stderr, "optionalI8 isa TorchOptional: %d\n",
          torchMlirTypeIsATorchOptional(optionalI8));
  // CHECK: optionalI8 isa TorchOptional: 1

  MlirType containedType = torchMlirTorchOptionalTypeGetContained(optionalI8);
  fprintf(stderr, "optionalI8 containedType: ");
  mlirTypePrint(containedType, printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: optionalI8 containedType: i8

  MlirType f16 = mlirF16TypeGet(ctx);
  MlirType f32 = mlirF32TypeGet(ctx);
  MlirType _tupleI8[3] = {i8, f16, f32};
#define DEFINE_CHECK(TTT) \
  MlirType TTT##I8 = torchMlirTorch##TTT##TypeGet(ctx, 3, _tupleI8); \
  \
  fprintf(stderr, #TTT "I8 isa " #TTT ": %d\n", \
          torchMlirTypeIsATorch##TTT(TTT##I8)); \
  \
  fprintf(stderr, #TTT "I8 NumTypes: %zu\n", \
          torchMlirTorch##TTT##TypeGetNumTypes(TTT##I8)); \
  \
  for (int i = 0; i < 3; ++i) { \
    fprintf(stderr, #TTT "I8 pos %d type: ", i); \
    mlirTypePrint(torchMlirTorch##TTT##TypeGetType(TTT##I8, i), printToStderr, \
                  NULL); \
    fprintf(stderr, "\n"); \
  }
  DEFINE_CHECK(Tuple)
  DEFINE_CHECK(Union)
#undef DEFINE_CHECK
  // CHECK: TupleI8 isa Tuple: 1
  // CHECK: TupleI8 NumTypes: 3
  // CHECK: TupleI8 pos 0 type: i8
  // CHECK: TupleI8 pos 1 type: f16
  // CHECK: TupleI8 pos 2 type: f32
  // CHECK: UnionI8 isa Union: 1
  // CHECK: UnionI8 NumTypes: 3
  // CHECK: UnionI8 pos 0 type: i8
  // CHECK: UnionI8 pos 1 type: f16
  // CHECK: UnionI8 pos 2 type: f32

  int64_t sizes[3] = {1, 2, 3};
  testTensor(ctx, 3, sizes, f32, "has-sizes-dtype");
  // CHECK: NonValueTensorType has-sizes-dtype hasSizes: 1
  // CHECK: NonValueTensorType has-sizes-dtype hasDtype: 1
  // CHECK: NonValueTensorType has-sizes-dtype rank: 3
  // CHECK: NonValueTensorType has-sizes-dtype pos 0 size: 1
  // CHECK: NonValueTensorType has-sizes-dtype pos 1 size: 2
  // CHECK: NonValueTensorType has-sizes-dtype pos 2 size: 3
  // CHECK: NonValueTensorType has-sizes-dtype dtype: f32
  // CHECK: ValueTensorType has-sizes-dtype hasSizes: 1
  // CHECK: ValueTensorType has-sizes-dtype hasDtype: 1
  // CHECK: ValueTensorType has-sizes-dtype rank: 3
  // CHECK: ValueTensorType has-sizes-dtype pos 0 size: 1
  // CHECK: ValueTensorType has-sizes-dtype pos 1 size: 2
  // CHECK: ValueTensorType has-sizes-dtype pos 2 size: 3
  // CHECK: ValueTensorType has-sizes-dtype dtype: f32

  MlirType nullType = {NULL};
  testTensor(ctx, 3, sizes, nullType, "has-sizes-no-dtype");
  // CHECK: NonValueTensorType has-sizes-no-dtype hasSizes: 1
  // CHECK: NonValueTensorType has-sizes-no-dtype hasDtype: 0
  // CHECK: NonValueTensorType has-sizes-no-dtype rank: 3
  // CHECK: NonValueTensorType has-sizes-no-dtype pos 0 size: 1
  // CHECK: NonValueTensorType has-sizes-no-dtype pos 1 size: 2
  // CHECK: NonValueTensorType has-sizes-no-dtype pos 2 size: 3
  // CHECK: ValueTensorType has-sizes-no-dtype hasSizes: 1
  // CHECK: ValueTensorType has-sizes-no-dtype hasDtype: 0
  // CHECK: ValueTensorType has-sizes-no-dtype rank: 3
  // CHECK: ValueTensorType has-sizes-no-dtype pos 0 size: 1
  // CHECK: ValueTensorType has-sizes-no-dtype pos 1 size: 2
  // CHECK: ValueTensorType has-sizes-no-dtype pos 2 size: 3
  testTensor(ctx, -1, sizes, f32, "no-sizes-has-dtype");
  // CHECK: NonValueTensorType no-sizes-has-dtype hasSizes: 0
  // CHECK: NonValueTensorType no-sizes-has-dtype hasDtype: 1
  // CHECK: NonValueTensorType no-sizes-has-dtype dtype: f32
  // CHECK: ValueTensorType no-sizes-has-dtype hasSizes: 0
  // CHECK: ValueTensorType no-sizes-has-dtype hasDtype: 1
  // CHECK: ValueTensorType no-sizes-has-dtype dtype: f32

  MlirType floatType = torchMlirTorchFloatTypeGet(ctx);
  torchMlirTorchDictTypeGetChecked(ctx, f16, floatType);
  // CHECK: error: invalid 'f16' for !torch.dict key type
  torchMlirTorchDictTypeGetChecked(ctx, i8, floatType);
  // CHECK: error: invalid 'i8' for !torch.dict key type
  torchMlirTorchDictTypeGetChecked(ctx, floatType, f16);
  // CHECK: error: invalid 'f16' for !torch.dict value type
  torchMlirTorchDictTypeGetChecked(ctx, floatType, i8);
  // CHECK: error: invalid 'i8' for !torch.dict value type

  MlirType strType = torchMlirTorchStringTypeGet(ctx);

  MlirType dictType1 = torchMlirTorchDictTypeGet(strType, floatType);

  fprintf(stderr, "dict keyType: ");
  mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: dict keyType: !torch.str
  fprintf(stderr, "dict valueType: ");
  mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: dict valueType: !torch.float

  MlirType dictType2 = torchMlirTorchDictTypeGet(floatType, strType);

  fprintf(stderr, "dict keyType: ");
  mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: dict keyType: !torch.float
  fprintf(stderr, "dict valueType: ");
  mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, NULL);
  fprintf(stderr, "\n");
  // CHECK: dict valueType: !torch.str
}

int main(void) {
  MlirContext ctx = mlirContextCreate();
  torchMlirRegisterAllDialects(ctx);
  testTypeMetaDataAccessors(ctx);
  mlirContextDestroy(ctx);
  return EXIT_SUCCESS;
}

@@ -23,5 +23,3 @@ add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")

add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})

add_subdirectory(CAPI)

@@ -56,8 +56,6 @@ config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')

# Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree builds.
llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True)

# On Windows the path to python could contain spaces, in which case it needs to
# be provided in quotes. This is the equivalent of how %python is setup in

@@ -9,7 +9,6 @@ config.host_os = "@HOST_OS@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_build_dir = "@CMAKE_BINARY_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@"