Revert "Expose metadata of torch-mlir types (plus verify DictType key and value types). (#1785)"

This reverts commit 8696752eb6.
revert-1785-expose_types
Chi_Liu 2023-01-20 01:11:45 -08:00 committed by GitHub
parent 2587b3f583
commit 745a616343
11 changed files with 1 additions and 414 deletions

View File

@ -190,7 +190,6 @@ add_custom_target(check-torch-mlir-all)
add_dependencies(check-torch-mlir-all add_dependencies(check-torch-mlir-all
check-torch-mlir check-torch-mlir
check-torch-mlir-dialects check-torch-mlir-dialects
check-torch-mlir-capi
) )
if(MLIR_ENABLE_BINDINGS_PYTHON) if(MLIR_ENABLE_BINDINGS_PYTHON)

View File

@ -40,10 +40,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t);
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGet(MlirType containedType); torchMlirTorchOptionalTypeGet(MlirType containedType);
/// Gets the subtype T of !torch.optional<T> type.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchOptionalTypeGetContained(MlirType containedType);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type. // torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -56,12 +52,6 @@ MLIR_CAPI_EXPORTED MlirType
torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes); MlirType const *containedTypes);
/// Returns the number of types contained in a !torch.tuple type.
MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t);
/// Returns the pos-th type in the !torch.tuple type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type. // torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -74,12 +64,6 @@ MLIR_CAPI_EXPORTED MlirType
torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes,
MlirType const *containedTypes); MlirType const *containedTypes);
/// Returns the number of types contained in a !torch.union type.
MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t);
/// Returns the pos-th type in the !torch.union type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.list<T> type. // torch.list<T> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -90,9 +74,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t);
/// Gets the !torch.list<T> type with contained T. /// Gets the !torch.list<T> type with contained T.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType);
/// Gets contained T in a !torch.list<T> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Device type. // torch.Device type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -201,22 +182,6 @@ torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr); torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);
/// Gets the the rank (number of dimensions) of a !torch.tensor
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t);
/// Return true if this type has a list of sizes.
MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t);
/// Return true if this type has a dtype.
MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t);
/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size
/// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
/// Gets the the dtype (data type) of a !torch.tensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.vtensor type. // torch.vtensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -243,22 +208,6 @@ torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
MLIR_CAPI_EXPORTED MlirType MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr); torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr);
/// Gets the the rank (number of dimensions) of a !torch.vtensor
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t);
/// Return true if this type has a list of sizes.
MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasSizes(MlirType t);
/// Return true if this type has a dtype.
MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t);
/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size
/// indicates an unrefined/unknown size dimension.
MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes);
/// Gets the the dtype (data type) of a !torch.vtensor.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t);
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// !torch.none type. // !torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -310,15 +259,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType, MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType,
MlirType valueType); MlirType valueType);
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetChecked(
MlirContext context, MlirType keyType, MlirType valueType);
/// Gets the key type of a !torch.dict<key, value> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t);
/// Gets the value type of a !torch.dict<key, value> type.
MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -347,7 +347,6 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> {
Torch Dict type with key and value type. Torch Dict type with key and value type.
}]; }];
let hasCustomAssemblyFormat = 1; let hasCustomAssemblyFormat = 1;
let genVerifyDecl = 1;
let builders = [ let builders = [
TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType, TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType,
"::mlir::Type":$valueType), [{ "::mlir::Type":$valueType), [{
@ -412,6 +411,7 @@ def AnyTorchDictKeyType : AnyTypeOf<[
Torch_BoolType, Torch_BoolType,
Torch_FloatType, Torch_FloatType,
Torch_StringType, Torch_StringType,
Torch_FloatType,
AnyTorchTensorType, AnyTorchTensorType,
], "Allowed dict key types">; ], "Allowed dict key types">;

View File

@ -42,11 +42,6 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType))); return wrap(Torch::OptionalType::get(unwrap(containedType)));
} }
MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) {
auto type = unwrap(t).cast<Torch::OptionalType>();
return wrap(type.getContainedType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.tuple<T1, T2, T3> type. // torch.tuple<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -65,16 +60,6 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context,
[](MlirType t) { return unwrap(t); })))); [](MlirType t) { return unwrap(t); }))));
} }
size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) {
auto type = unwrap(t).cast<Torch::TupleType>();
return type.getContainedTypes().size();
}
MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) {
auto type = unwrap(t).cast<Torch::TupleType>();
return wrap(type.getContainedTypes()[pos]);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.union<T1, T2, T3> type. // torch.union<T1, T2, T3> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -93,16 +78,6 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context,
[](MlirType t) { return unwrap(t); })))); [](MlirType t) { return unwrap(t); }))));
} }
size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) {
auto type = unwrap(t).cast<Torch::UnionType>();
return type.getContainedTypes().size();
}
MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) {
auto type = unwrap(t).cast<Torch::UnionType>();
return wrap(type.getContainedTypes()[pos]);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.list<T> type. // torch.list<T> type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -115,10 +90,6 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) {
return wrap(Torch::ListType::get(unwrap(containedType))); return wrap(Torch::ListType::get(unwrap(containedType)));
} }
MlirType torchMlirTorchListTypeGetContainedType(MlirType t) {
return wrap(unwrap(t).cast<Torch::ListType>().getContainedType());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.Device type. // torch.Device type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -249,35 +220,6 @@ MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
attrTensorType.getElementType())); attrTensorType.getElementType()));
} }
int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) {
return unwrap(t).cast<Torch::NonValueTensorType>().getSizes().size();
}
bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) {
return unwrap(t).cast<Torch::NonValueTensorType>().hasSizes();
}
bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) {
return unwrap(t).cast<Torch::NonValueTensorType>().hasDtype();
}
int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
auto tensorType = unwrap(t).cast<Torch::NonValueTensorType>();
bool hasSizes = tensorType.hasSizes();
if (!hasSizes)
return -1;
auto sizes_ = tensorType.getSizes();
for (const auto &s : llvm::enumerate(sizes_)) {
sizes[s.index()] = s.value();
}
return 0;
}
MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) {
return wrap(unwrap(t).cast<Torch::NonValueTensorType>().getDtype());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.vtensor type. // torch.vtensor type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -312,35 +254,6 @@ MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) {
attrTensorType.getElementType())); attrTensorType.getElementType()));
} }
int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) {
return unwrap(t).cast<Torch::ValueTensorType>().getSizes().size();
}
bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) {
return unwrap(t).cast<Torch::ValueTensorType>().hasSizes();
}
bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) {
return unwrap(t).cast<Torch::ValueTensorType>().hasDtype();
}
int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) {
auto tensorType = unwrap(t).cast<Torch::ValueTensorType>();
bool hasSizes = tensorType.hasSizes();
if (!hasSizes)
return -1;
auto sizes_ = tensorType.getSizes();
for (const auto &s : llvm::enumerate(sizes_)) {
sizes[s.index()] = s.value();
}
return 0;
}
MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) {
return wrap(unwrap(t).cast<Torch::ValueTensorType>().getDtype());
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// torch.none type. // torch.none type.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -400,20 +313,3 @@ bool torchMlirTypeIsATorchDict(MlirType t) {
MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) { MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) {
return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType))); return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType)));
} }
MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType,
MlirType valueType) {
auto unknownLoc = unwrap(mlirLocationUnknownGet(context));
return wrap(Torch::DictType::getChecked(unknownLoc, unwrap(context),
unwrap(keyType), unwrap(valueType)));
}
MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) {
auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getKeyType());
}
MlirType torchMlirTorchDictTypeGetValueType(MlirType t) {
auto type = unwrap(t).cast<Torch::DictType>();
return wrap(type.getValueType());
}

View File

@ -465,44 +465,3 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) {
return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype); return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype);
} }
////===----------------------------------------------------------------------===//
//// DictType
////===----------------------------------------------------------------------===//
// TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType
// and AnyTorchType generate the exact same code (in TorchOps.cpp.inc).
// Unfortunately the generated implementations aren't visible/exposed ("static" linkage)
// and the predicates themselves can't be added/used in the specification of the parameters
// of the Torch_DictType.
static bool isAnyTorchDictKeyType(Type type) {
return type.isa<Torch::AnyType>() || type.isa<Torch::IntType>() ||
type.isa<Torch::BoolType>() || type.isa<Torch::FloatType>() ||
type.isa<Torch::StringType>() || type.isa<Torch::BaseTensorType>();
}
static bool isAnyTorchType(Type type) {
return isValidSubtype(type, Torch::NumberType::get(type.getContext())) ||
type.isa<Torch::BaseTensorType>() || type.isa<Torch::AnyType>() ||
type.isa<Torch::BoolType>() || type.isa<Torch::DictType>() ||
type.isa<Torch::DeviceType>() || type.isa<Torch::GeneratorType>() ||
type.isa<Torch::ListType>() || type.isa<Torch::LinearParamsType>() ||
type.isa<Torch::NumberType>() || type.isa<Torch::NnModuleType>() ||
type.isa<Torch::NoneType>() || type.isa<Torch::OptionalType>() ||
type.isa<Torch::StringType>() || type.isa<Torch::TupleType>() ||
type.isa<Torch::UnionType>();
}
LogicalResult
DictType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
Type keyType, Type valueType) {
if (!isAnyTorchDictKeyType(keyType)) {
emitError() << "invalid " << keyType << " for !torch.dict key type";
return failure();
}
if (!isAnyTorchType(valueType)) {
emitError() << "invalid " << valueType << " for !torch.dict value type";
return failure();
}
return success();
}

View File

@ -1,15 +0,0 @@
add_llvm_executable(torch-mlir-capi-torch-test torch.c)
llvm_update_compile_flags(torch-mlir-capi-torch-test)
target_link_libraries(
torch-mlir-capi-torch-test
PRIVATE
MLIRCAPIIR
MLIRCAPIRegisterEverything
TorchMLIRCAPI
)
add_lit_testsuite(check-torch-mlir-capi "Running the torch-mlir CAPI tests"
${CMAKE_CURRENT_BINARY_DIR}
DEPENDS torch-mlir-capi-torch-test
)
set_target_properties(check-torch-mlir-capi PROPERTIES FOLDER "Tests")

View File

@ -1 +0,0 @@
config.suffixes.add('.c')

View File

@ -1,186 +0,0 @@
//===- torch.c - Test of Torch dialect C API ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// RUN: torch-mlir-capi-torch-test 2>&1 | FileCheck %s
#include "mlir-c/BuiltinTypes.h"
#include "torch-mlir-c/Registration.h"
#include "torch-mlir-c/TorchTypes.h"
#include <inttypes.h>
#include <stdio.h>
#include <stdlib.h>
static void printToStderr(MlirStringRef str, void *userData) {
(void)userData;
fwrite(str.data, 1, str.length, stderr);
}
static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes,
MlirType dType, const char *testName) {
#define DEFINE_CHECK(TTT) \
MlirType TTT##Type = \
torchMlirTorch##TTT##TypeGet(ctx, numSizes, sizes, dType); \
\
bool TTT##hasSizes = torchMlirTorch##TTT##TypeHasSizes(TTT##Type); \
fprintf(stderr, #TTT "Type %s hasSizes: %d\n", testName, TTT##hasSizes); \
bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \
fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \
if (TTT##hasSizes) { \
fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \
torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \
int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \
torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \
for (int i = 0; i < numSizes; ++i) { \
fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \
TTT##Sizes[i]); \
} \
} \
\
if (TTT##hasDtype) { \
MlirType TTT##Dtype = torchMlirTorch##TTT##TypeGetDtype(TTT##Type); \
fprintf(stderr, #TTT "Type %s dtype: ", testName); \
mlirTypePrint(TTT##Dtype, printToStderr, NULL); \
fprintf(stderr, "\n"); \
}
DEFINE_CHECK(NonValueTensor)
DEFINE_CHECK(ValueTensor)
#undef DEFINE_CHECK
}
// CHECK-LABEL: testTypeMetaDataAccessors
static void testTypeMetaDataAccessors(MlirContext ctx) {
fprintf(stderr, "testTypeMetaDataAccessors\n");
MlirType i8 = mlirIntegerTypeGet(ctx, 8);
MlirType optionalI8 = torchMlirTorchOptionalTypeGet(i8);
fprintf(stderr, "optionalI8 isa TorchOptional: %d\n",
torchMlirTypeIsATorchOptional(optionalI8));
// CHECK: optionalI8 isa TorchOptional: 1
MlirType containedType = torchMlirTorchOptionalTypeGetContained(optionalI8);
fprintf(stderr, "optionalI8 containedType: ");
mlirTypePrint(containedType, printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: optionalI8 containedType: i8
MlirType f16 = mlirF16TypeGet(ctx);
MlirType f32 = mlirF32TypeGet(ctx);
MlirType _tupleI8[3] = {i8, f16, f32};
#define DEFINE_CHECK(TTT) \
MlirType TTT##I8 = torchMlirTorch##TTT##TypeGet(ctx, 3, _tupleI8); \
\
fprintf(stderr, #TTT "I8 isa " #TTT ": %d\n", \
torchMlirTypeIsATorch##TTT(TTT##I8)); \
\
fprintf(stderr, #TTT "I8 NumTypes: %zu\n", \
torchMlirTorch##TTT##TypeGetNumTypes(TTT##I8)); \
\
for (int i = 0; i < 3; ++i) { \
fprintf(stderr, #TTT "I8 pos %d type: ", i); \
mlirTypePrint(torchMlirTorch##TTT##TypeGetType(TTT##I8, i), printToStderr, \
NULL); \
fprintf(stderr, "\n"); \
}
DEFINE_CHECK(Tuple)
DEFINE_CHECK(Union)
#undef DEFINE_CHECK
// CHECK: TupleI8 isa Tuple: 1
// CHECK: TupleI8 NumTypes: 3
// CHECK: TupleI8 pos 0 type: i8
// CHECK: TupleI8 pos 1 type: f16
// CHECK: TupleI8 pos 2 type: f32
// CHECK: UnionI8 isa Union: 1
// CHECK: UnionI8 NumTypes: 3
// CHECK: UnionI8 pos 0 type: i8
// CHECK: UnionI8 pos 1 type: f16
// CHECK: UnionI8 pos 2 type: f32
int64_t sizes[3] = {1, 2, 3};
testTensor(ctx, 3, sizes, f32, "has-sizes-dtype");
// CHECK: NonValueTensorType has-sizes-dtype hasSizes: 1
// CHECK: NonValueTensorType has-sizes-dtype hasDtype: 1
// CHECK: NonValueTensorType has-sizes-dtype rank: 3
// CHECK: NonValueTensorType has-sizes-dtype pos 0 size: 1
// CHECK: NonValueTensorType has-sizes-dtype pos 1 size: 2
// CHECK: NonValueTensorType has-sizes-dtype pos 2 size: 3
// CHECK: NonValueTensorType has-sizes-dtype dtype: f32
// CHECK: ValueTensorType has-sizes-dtype hasSizes: 1
// CHECK: ValueTensorType has-sizes-dtype hasDtype: 1
// CHECK: ValueTensorType has-sizes-dtype rank: 3
// CHECK: ValueTensorType has-sizes-dtype pos 0 size: 1
// CHECK: ValueTensorType has-sizes-dtype pos 1 size: 2
// CHECK: ValueTensorType has-sizes-dtype pos 2 size: 3
// CHECK: ValueTensorType has-sizes-dtype dtype: f32
MlirType nullType = {NULL};
testTensor(ctx, 3, sizes, nullType, "has-sizes-no-dtype");
// CHECK: NonValueTensorType has-sizes-no-dtype hasSizes: 1
// CHECK: NonValueTensorType has-sizes-no-dtype hasDtype: 0
// CHECK: NonValueTensorType has-sizes-no-dtype rank: 3
// CHECK: NonValueTensorType has-sizes-no-dtype pos 0 size: 1
// CHECK: NonValueTensorType has-sizes-no-dtype pos 1 size: 2
// CHECK: NonValueTensorType has-sizes-no-dtype pos 2 size: 3
// CHECK: ValueTensorType has-sizes-no-dtype hasSizes: 1
// CHECK: ValueTensorType has-sizes-no-dtype hasDtype: 0
// CHECK: ValueTensorType has-sizes-no-dtype rank: 3
// CHECK: ValueTensorType has-sizes-no-dtype pos 0 size: 1
// CHECK: ValueTensorType has-sizes-no-dtype pos 1 size: 2
// CHECK: ValueTensorType has-sizes-no-dtype pos 2 size: 3
testTensor(ctx, -1, sizes, f32, "no-sizes-has-dtype");
// CHECK: NonValueTensorType no-sizes-has-dtype hasSizes: 0
// CHECK: NonValueTensorType no-sizes-has-dtype hasDtype: 1
// CHECK: NonValueTensorType no-sizes-has-dtype dtype: f32
// CHECK: ValueTensorType no-sizes-has-dtype hasSizes: 0
// CHECK: ValueTensorType no-sizes-has-dtype hasDtype: 1
// CHECK: ValueTensorType no-sizes-has-dtype dtype: f32
MlirType floatType = torchMlirTorchFloatTypeGet(ctx);
torchMlirTorchDictTypeGetChecked(ctx, f16, floatType);
// CHECK: error: invalid 'f16' for !torch.dict key type
torchMlirTorchDictTypeGetChecked(ctx, i8, floatType);
// CHECK: error: invalid 'i8' for !torch.dict key type
torchMlirTorchDictTypeGetChecked(ctx, floatType, f16);
// CHECK: error: invalid 'f16' for !torch.dict value type
torchMlirTorchDictTypeGetChecked(ctx, floatType, i8);
// CHECK: error: invalid 'i8' for !torch.dict value type
MlirType strType = torchMlirTorchStringTypeGet(ctx);
MlirType dictType1 = torchMlirTorchDictTypeGet(strType, floatType);
fprintf(stderr, "dict keyType: ");
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: dict keyType: !torch.str
fprintf(stderr, "dict valueType: ");
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: dict valueType: !torch.float
MlirType dictType2 = torchMlirTorchDictTypeGet(floatType, strType);
fprintf(stderr, "dict keyType: ");
mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: dict keyType: !torch.float
fprintf(stderr, "dict valueType: ");
mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, NULL);
fprintf(stderr, "\n");
// CHECK: dict valueType: !torch.str
}
int main(void) {
MlirContext ctx = mlirContextCreate();
torchMlirRegisterAllDialects(ctx);
testTypeMetaDataAccessors(ctx);
mlirContextDestroy(ctx);
return EXIT_SUCCESS;
}

View File

@ -23,5 +23,3 @@ add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests"
set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests") set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests")
add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS}) add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS})
add_subdirectory(CAPI)

View File

@ -56,8 +56,6 @@ config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin')
# Tweak the PATH to include the tools dir. # Tweak the PATH to include the tools dir.
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree.
llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True)
# On Windows the path to python could contains spaces in which case it needs to # On Windows the path to python could contains spaces in which case it needs to
# be provided in quotes. This is the equivalent of how %python is setup in # be provided in quotes. This is the equivalent of how %python is setup in

View File

@ -9,7 +9,6 @@ config.host_os = "@HOST_OS@"
config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_src_root = "@LLVM_SOURCE_DIR@"
config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@"
config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@"
config.llvm_build_dir = "@CMAKE_BINARY_DIR@"
config.llvm_lib_dir = "@LLVM_LIBS_DIR@" config.llvm_lib_dir = "@LLVM_LIBS_DIR@"
config.llvm_shlib_dir = "@SHLIBDIR@" config.llvm_shlib_dir = "@SHLIBDIR@"
config.llvm_shlib_ext = "@SHLIBEXT@" config.llvm_shlib_ext = "@SHLIBEXT@"