diff --git a/CMakeLists.txt b/CMakeLists.txt index c20627065..c1a96d44b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -190,7 +190,6 @@ add_custom_target(check-torch-mlir-all) add_dependencies(check-torch-mlir-all check-torch-mlir check-torch-mlir-dialects - check-torch-mlir-capi ) if(MLIR_ENABLE_BINDINGS_PYTHON) diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index 4b0143d19..e499c6744 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -40,10 +40,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchOptional(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchOptionalTypeGet(MlirType containedType); -/// Gets the subtype T of !torch.optional type. -MLIR_CAPI_EXPORTED MlirType -torchMlirTorchOptionalTypeGetContained(MlirType containedType); - //===----------------------------------------------------------------------===// // torch.tuple type. //===----------------------------------------------------------------------===// @@ -56,12 +52,6 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGet(MlirContext context, intptr_t numContainedTypes, MlirType const *containedTypes); -/// Returns the number of types contained in a !torch.tuple type. -MLIR_CAPI_EXPORTED size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t); - -/// Returns the pos-th type in the !torch.tuple type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos); - //===----------------------------------------------------------------------===// // torch.union type. //===----------------------------------------------------------------------===// @@ -74,12 +64,6 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGet(MlirContext context, intptr_t numContainedTypes, MlirType const *containedTypes); -/// Returns the number of types contained in a !torch.union type. -MLIR_CAPI_EXPORTED size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t); - -/// Returns the pos-th type in the !torch.union type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos); - //===----------------------------------------------------------------------===// // torch.list type. //===----------------------------------------------------------------------===// @@ -90,9 +74,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchList(MlirType t); /// Gets the !torch.list type with contained T. MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); -/// Gets contained T in a !torch.list type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t); - //===----------------------------------------------------------------------===// // torch.Device type. //===----------------------------------------------------------------------===// @@ -201,22 +182,6 @@ torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation( MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr); -/// Gets the the rank (number of dimensions) of a !torch.tensor -MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t); - -/// Return true if this type has a list of sizes. -MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t); - -/// Return true if this type has a dtype. -MLIR_CAPI_EXPORTED bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t); - -/// Gets the the sizes of the dimensions of a !torch.tensor; note -1 size -/// indicates an unrefined/unknown size dimension. -MLIR_CAPI_EXPORTED int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes); - -/// Gets the the dtype (data type) of a !torch.tensor. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); - //===----------------------------------------------------------------------===// // torch.vtensor type. //===----------------------------------------------------------------------===// @@ -243,22 +208,6 @@ torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context); MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr); -/// Gets the the rank (number of dimensions) of a !torch.vtensor -MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t); - -/// Return true if this type has a list of sizes. -MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasSizes(MlirType t); - -/// Return true if this type has a dtype. -MLIR_CAPI_EXPORTED bool torchMlirTorchValueTensorTypeHasDtype(MlirType t); - -/// Gets the the sizes of the dimensions of a !torch.vtensor; note -1 size -/// indicates an unrefined/unknown size dimension. -MLIR_CAPI_EXPORTED int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); - -/// Gets the the dtype (data type) of a !torch.vtensor. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t); - //===----------------------------------------------------------------------===// // !torch.none type. //===----------------------------------------------------------------------===// @@ -310,15 +259,6 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDict(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType); -MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetChecked( - MlirContext context, MlirType keyType, MlirType valueType); - -/// Gets the key type of a !torch.dict type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t); - -/// Gets the value type of a !torch.dict type. -MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t); - #ifdef __cplusplus } #endif diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index e168eaea2..f6f64d6ca 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -347,7 +347,6 @@ def Torch_DictType : Torch_Type<"Dict", "dict"> { Torch Dict type with key and value type. }]; let hasCustomAssemblyFormat = 1; - let genVerifyDecl = 1; let builders = [ TypeBuilderWithInferredContext<(ins "::mlir::Type":$keyType, "::mlir::Type":$valueType), [{ @@ -412,6 +411,7 @@ def AnyTorchDictKeyType : AnyTypeOf<[ Torch_BoolType, Torch_FloatType, Torch_StringType, + Torch_FloatType, AnyTorchTensorType, ], "Allowed dict key types">; diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp index 7609d89b4..4d50530f1 100644 --- a/lib/CAPI/TorchTypes.cpp +++ b/lib/CAPI/TorchTypes.cpp @@ -42,11 +42,6 @@ MlirType torchMlirTorchOptionalTypeGet(MlirType containedType) { return wrap(Torch::OptionalType::get(unwrap(containedType))); } -MlirType torchMlirTorchOptionalTypeGetContained(MlirType t) { - auto type = unwrap(t).cast(); - return wrap(type.getContainedType()); -} - //===----------------------------------------------------------------------===// // torch.tuple type. //===----------------------------------------------------------------------===// @@ -65,16 +60,6 @@ MlirType torchMlirTorchTupleTypeGet(MlirContext context, [](MlirType t) { return unwrap(t); })))); } -size_t torchMlirTorchTupleTypeGetNumTypes(MlirType t) { - auto type = unwrap(t).cast(); - return type.getContainedTypes().size(); -} - -MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos) { - auto type = unwrap(t).cast(); - return wrap(type.getContainedTypes()[pos]); -} - //===----------------------------------------------------------------------===// // torch.union type. //===----------------------------------------------------------------------===// @@ -93,16 +78,6 @@ MlirType torchMlirTorchUnionTypeGet(MlirContext context, [](MlirType t) { return unwrap(t); })))); } -size_t torchMlirTorchUnionTypeGetNumTypes(MlirType t) { - auto type = unwrap(t).cast(); - return type.getContainedTypes().size(); -} - -MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos) { - auto type = unwrap(t).cast(); - return wrap(type.getContainedTypes()[pos]); -} - //===----------------------------------------------------------------------===// // torch.list type. //===----------------------------------------------------------------------===// @@ -115,10 +90,6 @@ MlirType torchMlirTorchListTypeGet(MlirType containedType) { return wrap(Torch::ListType::get(unwrap(containedType))); } -MlirType torchMlirTorchListTypeGetContainedType(MlirType t) { - return wrap(unwrap(t).cast().getContainedType()); -} - //===----------------------------------------------------------------------===// // torch.Device type. //===----------------------------------------------------------------------===// @@ -249,35 +220,6 @@ MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) { attrTensorType.getElementType())); } -int64_t torchMlirTorchNonValueTensorTypeGetRank(MlirType t) { - return unwrap(t).cast().getSizes().size(); -} - -bool torchMlirTorchNonValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); -} - -bool torchMlirTorchNonValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); -} - -int64_t torchMlirTorchNonValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { - auto tensorType = unwrap(t).cast(); - bool hasSizes = tensorType.hasSizes(); - if (!hasSizes) - return -1; - - auto sizes_ = tensorType.getSizes(); - for (const auto &s : llvm::enumerate(sizes_)) { - sizes[s.index()] = s.value(); - } - return 0; -} - -MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t) { - return wrap(unwrap(t).cast().getDtype()); -} - //===----------------------------------------------------------------------===// // torch.vtensor type. //===----------------------------------------------------------------------===// @@ -312,35 +254,6 @@ MlirType torchMlirTorchValueTensorTypeGetFromAttribute(MlirAttribute attr) { attrTensorType.getElementType())); } -int64_t torchMlirTorchValueTensorTypeGetRank(MlirType t) { - return unwrap(t).cast().getSizes().size(); -} - -bool torchMlirTorchValueTensorTypeHasSizes(MlirType t) { - return unwrap(t).cast().hasSizes(); -} - -bool torchMlirTorchValueTensorTypeHasDtype(MlirType t) { - return unwrap(t).cast().hasDtype(); -} - -int64_t torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes) { - auto tensorType = unwrap(t).cast(); - bool hasSizes = tensorType.hasSizes(); - if (!hasSizes) - return -1; - - auto sizes_ = tensorType.getSizes(); - for (const auto &s : llvm::enumerate(sizes_)) { - sizes[s.index()] = s.value(); - } - return 0; -} - -MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t) { - return wrap(unwrap(t).cast().getDtype()); -} - //===----------------------------------------------------------------------===// // torch.none type. //===----------------------------------------------------------------------===// @@ -400,20 +313,3 @@ bool torchMlirTypeIsATorchDict(MlirType t) { MlirType torchMlirTorchDictTypeGet(MlirType keyType, MlirType valueType) { return wrap(Torch::DictType::get(unwrap(keyType), unwrap(valueType))); } - -MlirType torchMlirTorchDictTypeGetChecked(MlirContext context, MlirType keyType, - MlirType valueType) { - auto unknownLoc = unwrap(mlirLocationUnknownGet(context)); - return wrap(Torch::DictType::getChecked(unknownLoc, unwrap(context), - unwrap(keyType), unwrap(valueType))); -} - -MlirType torchMlirTorchDictTypeGetKeyType(MlirType t) { - auto type = unwrap(t).cast(); - return wrap(type.getKeyType()); -} - -MlirType torchMlirTorchDictTypeGetValueType(MlirType t) { - auto type = unwrap(t).cast(); - return wrap(type.getValueType()); -} diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 968f809a0..f97804032 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -465,44 +465,3 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { return lhs.getWithSizesAndDtype(makeArrayRef(newSizes), dtype); } - -////===----------------------------------------------------------------------===// -//// DictType -////===----------------------------------------------------------------------===// - -// TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType -// and AnyTorchType generate the exact same code (in TorchOps.cpp.inc). -// Unfortunately the generated implementations aren't visible/exposed ("static" linkage) -// and the predicates themselves can't be added/used in the specification of the parameters -// of the Torch_DictType. -static bool isAnyTorchDictKeyType(Type type) { - return type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa(); -} - -static bool isAnyTorchType(Type type) { - return isValidSubtype(type, Torch::NumberType::get(type.getContext())) || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa() || type.isa() || - type.isa(); -} - -LogicalResult -DictType::verify(llvm::function_ref emitError, - Type keyType, Type valueType) { - if (!isAnyTorchDictKeyType(keyType)) { - emitError() << "invalid " << keyType << " for !torch.dict key type"; - return failure(); - } - if (!isAnyTorchType(valueType)) { - emitError() << "invalid " << valueType << " for !torch.dict value type"; - return failure(); - } - return success(); -} \ No newline at end of file diff --git a/test/CAPI/CMakeLists.txt b/test/CAPI/CMakeLists.txt deleted file mode 100644 index b48e3d24e..000000000 --- a/test/CAPI/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_llvm_executable(torch-mlir-capi-torch-test torch.c) -llvm_update_compile_flags(torch-mlir-capi-torch-test) -target_link_libraries( - torch-mlir-capi-torch-test - PRIVATE - MLIRCAPIIR - MLIRCAPIRegisterEverything - TorchMLIRCAPI -) - -add_lit_testsuite(check-torch-mlir-capi "Running the torch-mlir CAPI tests" - ${CMAKE_CURRENT_BINARY_DIR} - DEPENDS torch-mlir-capi-torch-test - ) -set_target_properties(check-torch-mlir-capi PROPERTIES FOLDER "Tests") \ No newline at end of file diff --git a/test/CAPI/lit.local.cfg b/test/CAPI/lit.local.cfg deleted file mode 100644 index 03902ac96..000000000 --- a/test/CAPI/lit.local.cfg +++ /dev/null @@ -1 +0,0 @@ -config.suffixes.add('.c') \ No newline at end of file diff --git a/test/CAPI/torch.c b/test/CAPI/torch.c deleted file mode 100644 index e9c5d23e2..000000000 --- a/test/CAPI/torch.c +++ /dev/null @@ -1,186 +0,0 @@ -//===- torch.c - Test of Torch dialect C API ------------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -// RUN: torch-mlir-capi-torch-test 2>&1 | FileCheck %s - -#include "mlir-c/BuiltinTypes.h" -#include "torch-mlir-c/Registration.h" -#include "torch-mlir-c/TorchTypes.h" - -#include -#include -#include - -static void printToStderr(MlirStringRef str, void *userData) { - (void)userData; - fwrite(str.data, 1, str.length, stderr); -} - -static void testTensor(MlirContext ctx, intptr_t numSizes, int64_t *sizes, - MlirType dType, const char *testName) { -#define DEFINE_CHECK(TTT) \ - MlirType TTT##Type = \ - torchMlirTorch##TTT##TypeGet(ctx, numSizes, sizes, dType); \ - \ - bool TTT##hasSizes = torchMlirTorch##TTT##TypeHasSizes(TTT##Type); \ - fprintf(stderr, #TTT "Type %s hasSizes: %d\n", testName, TTT##hasSizes); \ - bool TTT##hasDtype = torchMlirTorch##TTT##TypeHasDtype(TTT##Type); \ - fprintf(stderr, #TTT "Type %s hasDtype: %d\n", testName, TTT##hasDtype); \ - if (TTT##hasSizes) { \ - fprintf(stderr, #TTT "Type %s rank: %zu\n", testName, \ - torchMlirTorch##TTT##TypeGetRank(TTT##Type)); \ - int64_t *TTT##Sizes = malloc(sizeof(int64_t) * numSizes); \ - torchMlirTorch##TTT##TypeGetSizes(TTT##Type, TTT##Sizes); \ - for (int i = 0; i < numSizes; ++i) { \ - fprintf(stderr, #TTT "Type %s pos %d size: %ld\n", testName, i, \ - TTT##Sizes[i]); \ - } \ - } \ - \ - if (TTT##hasDtype) { \ - MlirType TTT##Dtype = torchMlirTorch##TTT##TypeGetDtype(TTT##Type); \ - fprintf(stderr, #TTT "Type %s dtype: ", testName); \ - mlirTypePrint(TTT##Dtype, printToStderr, NULL); \ - fprintf(stderr, "\n"); \ - } - DEFINE_CHECK(NonValueTensor) - DEFINE_CHECK(ValueTensor) -#undef DEFINE_CHECK -} - -// CHECK-LABEL: testTypeMetaDataAccessors -static void testTypeMetaDataAccessors(MlirContext ctx) { - fprintf(stderr, "testTypeMetaDataAccessors\n"); - - MlirType i8 = mlirIntegerTypeGet(ctx, 8); - MlirType optionalI8 = torchMlirTorchOptionalTypeGet(i8); - - fprintf(stderr, "optionalI8 isa TorchOptional: %d\n", - torchMlirTypeIsATorchOptional(optionalI8)); - // CHECK: optionalI8 isa TorchOptional: 1 - - MlirType containedType = torchMlirTorchOptionalTypeGetContained(optionalI8); - fprintf(stderr, "optionalI8 containedType: "); - mlirTypePrint(containedType, printToStderr, NULL); - fprintf(stderr, "\n"); - // CHECK: optionalI8 containedType: i8 - - MlirType f16 = mlirF16TypeGet(ctx); - MlirType f32 = mlirF32TypeGet(ctx); - MlirType _tupleI8[3] = {i8, f16, f32}; -#define DEFINE_CHECK(TTT) \ - MlirType TTT##I8 = torchMlirTorch##TTT##TypeGet(ctx, 3, _tupleI8); \ - \ - fprintf(stderr, #TTT "I8 isa " #TTT ": %d\n", \ - torchMlirTypeIsATorch##TTT(TTT##I8)); \ - \ - fprintf(stderr, #TTT "I8 NumTypes: %zu\n", \ - torchMlirTorch##TTT##TypeGetNumTypes(TTT##I8)); \ - \ - for (int i = 0; i < 3; ++i) { \ - fprintf(stderr, #TTT "I8 pos %d type: ", i); \ - mlirTypePrint(torchMlirTorch##TTT##TypeGetType(TTT##I8, i), printToStderr, \ - NULL); \ - fprintf(stderr, "\n"); \ - } - DEFINE_CHECK(Tuple) - DEFINE_CHECK(Union) -#undef DEFINE_CHECK - // CHECK: TupleI8 isa Tuple: 1 - // CHECK: TupleI8 NumTypes: 3 - // CHECK: TupleI8 pos 0 type: i8 - // CHECK: TupleI8 pos 1 type: f16 - // CHECK: TupleI8 pos 2 type: f32 - // CHECK: UnionI8 isa Union: 1 - // CHECK: UnionI8 NumTypes: 3 - // CHECK: UnionI8 pos 0 type: i8 - // CHECK: UnionI8 pos 1 type: f16 - // CHECK: UnionI8 pos 2 type: f32 - - int64_t sizes[3] = {1, 2, 3}; - testTensor(ctx, 3, sizes, f32, "has-sizes-dtype"); - // CHECK: NonValueTensorType has-sizes-dtype hasSizes: 1 - // CHECK: NonValueTensorType has-sizes-dtype hasDtype: 1 - // CHECK: NonValueTensorType has-sizes-dtype rank: 3 - // CHECK: NonValueTensorType has-sizes-dtype pos 0 size: 1 - // CHECK: NonValueTensorType has-sizes-dtype pos 1 size: 2 - // CHECK: NonValueTensorType has-sizes-dtype pos 2 size: 3 - // CHECK: NonValueTensorType has-sizes-dtype dtype: f32 - // CHECK: ValueTensorType has-sizes-dtype hasSizes: 1 - // CHECK: ValueTensorType has-sizes-dtype hasDtype: 1 - // CHECK: ValueTensorType has-sizes-dtype rank: 3 - // CHECK: ValueTensorType has-sizes-dtype pos 0 size: 1 - // CHECK: ValueTensorType has-sizes-dtype pos 1 size: 2 - // CHECK: ValueTensorType has-sizes-dtype pos 2 size: 3 - // CHECK: ValueTensorType has-sizes-dtype dtype: f32 - - MlirType nullType = {NULL}; - testTensor(ctx, 3, sizes, nullType, "has-sizes-no-dtype"); - // CHECK: NonValueTensorType has-sizes-no-dtype hasSizes: 1 - // CHECK: NonValueTensorType has-sizes-no-dtype hasDtype: 0 - // CHECK: NonValueTensorType has-sizes-no-dtype rank: 3 - // CHECK: NonValueTensorType has-sizes-no-dtype pos 0 size: 1 - // CHECK: NonValueTensorType has-sizes-no-dtype pos 1 size: 2 - // CHECK: NonValueTensorType has-sizes-no-dtype pos 2 size: 3 - // CHECK: ValueTensorType has-sizes-no-dtype hasSizes: 1 - // CHECK: ValueTensorType has-sizes-no-dtype hasDtype: 0 - // CHECK: ValueTensorType has-sizes-no-dtype rank: 3 - // CHECK: ValueTensorType has-sizes-no-dtype pos 0 size: 1 - // CHECK: ValueTensorType has-sizes-no-dtype pos 1 size: 2 - // CHECK: ValueTensorType has-sizes-no-dtype pos 2 size: 3 - testTensor(ctx, -1, sizes, f32, "no-sizes-has-dtype"); - // CHECK: NonValueTensorType no-sizes-has-dtype hasSizes: 0 - // CHECK: NonValueTensorType no-sizes-has-dtype hasDtype: 1 - // CHECK: NonValueTensorType no-sizes-has-dtype dtype: f32 - // CHECK: ValueTensorType no-sizes-has-dtype hasSizes: 0 - // CHECK: ValueTensorType no-sizes-has-dtype hasDtype: 1 - // CHECK: ValueTensorType no-sizes-has-dtype dtype: f32 - - MlirType floatType = torchMlirTorchFloatTypeGet(ctx); - torchMlirTorchDictTypeGetChecked(ctx, f16, floatType); - // CHECK: error: invalid 'f16' for !torch.dict key type - torchMlirTorchDictTypeGetChecked(ctx, i8, floatType); - // CHECK: error: invalid 'i8' for !torch.dict key type - torchMlirTorchDictTypeGetChecked(ctx, floatType, f16); - // CHECK: error: invalid 'f16' for !torch.dict value type - torchMlirTorchDictTypeGetChecked(ctx, floatType, i8); - // CHECK: error: invalid 'i8' for !torch.dict value type - - MlirType strType = torchMlirTorchStringTypeGet(ctx); - - MlirType dictType1 = torchMlirTorchDictTypeGet(strType, floatType); - - fprintf(stderr, "dict keyType: "); - mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType1), printToStderr, NULL); - fprintf(stderr, "\n"); - // CHECK: dict keyType: !torch.str - fprintf(stderr, "dict valueType: "); - mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType1), printToStderr, NULL); - fprintf(stderr, "\n"); - // CHECK: dict valueType: !torch.float - - MlirType dictType2 = torchMlirTorchDictTypeGet(floatType, strType); - - fprintf(stderr, "dict keyType: "); - mlirTypePrint(torchMlirTorchDictTypeGetKeyType(dictType2), printToStderr, NULL); - fprintf(stderr, "\n"); - // CHECK: dict keyType: !torch.float - fprintf(stderr, "dict valueType: "); - mlirTypePrint(torchMlirTorchDictTypeGetValueType(dictType2), printToStderr, NULL); - fprintf(stderr, "\n"); - // CHECK: dict valueType: !torch.str -} - -int main(void) { - MlirContext ctx = mlirContextCreate(); - torchMlirRegisterAllDialects(ctx); - testTypeMetaDataAccessors(ctx); - mlirContextDestroy(ctx); - return EXIT_SUCCESS; -} diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8b444bd1d..a5098d8aa 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -23,5 +23,3 @@ add_lit_testsuite(check-torch-mlir "Running the torch-mlir regression tests" set_target_properties(check-torch-mlir PROPERTIES FOLDER "Tests") add_lit_testsuites(TORCH_MLIR ${CMAKE_CURRENT_SOURCE_DIR} DEPENDS ${TORCH_MLIR_TEST_DEPENDS}) - -add_subdirectory(CAPI) diff --git a/test/lit.cfg.py b/test/lit.cfg.py index a9753bf22..922980ff9 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -56,8 +56,6 @@ config.standalone_tools_dir = os.path.join(config.torch_mlir_obj_root, 'bin') # Tweak the PATH to include the tools dir. llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) -# Tweak the PATH to include the binary build dir, in order to pick up CAPI tests during out-of-tree. -llvm_config.with_environment('PATH', os.path.join(config.llvm_build_dir, 'bin'), append_path=True) # On Windows the path to python could contains spaces in which case it needs to # be provided in quotes. This is the equivalent of how %python is setup in diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 339975f3e..5dd0b2e52 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -9,7 +9,6 @@ config.host_os = "@HOST_OS@" config.llvm_src_root = "@LLVM_SOURCE_DIR@" config.llvm_obj_root = "@LLVM_BINARY_DIR@" config.llvm_tools_dir = "@LLVM_TOOLS_DIR@" -config.llvm_build_dir = "@CMAKE_BINARY_DIR@" config.llvm_lib_dir = "@LLVM_LIBS_DIR@" config.llvm_shlib_dir = "@SHLIBDIR@" config.llvm_shlib_ext = "@SHLIBEXT@"