torch-mlir/test/CAPI/ir.c

71 lines
2.0 KiB
C

/*===- ir.c - Simple test of C APIs ---------------------------------------===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
\*===----------------------------------------------------------------------===*/
/* RUN: npcomp-capi-ir-test 2>&1 | FileCheck %s
*/
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"
#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Dumps an instance of all NPComp types.
static int printStandardTypes(MlirContext ctx) {
// Bool type.
MlirType boolType = npcompBoolTypeGet(ctx);
if (!npcompTypeIsABool(boolType))
return 1;
mlirTypeDump(boolType);
fprintf(stderr, "\n");
// Any dtype.
MlirType anyDtype = npcompAnyDtypeTypeGet(ctx);
if (!npcompTypeIsAAnyDtype(anyDtype))
return 2;
mlirTypeDump(anyDtype);
fprintf(stderr, "\n");
// Ranked NdArray.
int64_t fourDim = 4;
MlirType rankedNdArray = npcompNdArrayTypeGetRanked(1, &fourDim, boolType);
if (!npcompTypeIsANdArray(rankedNdArray))
return 3;
mlirTypeDump(rankedNdArray);
fprintf(stderr, "\n");
return 0;
}
// Test driver: creates a context, registers the upstream MLIR dialects and
// the NPComp dialects, then prints a label line, the dumped NPComp types and
// the error code returned by printStandardTypes to stderr. The expected
// stderr output is pinned by the FileCheck directives below. The process
// exit status also reflects that error code, so a failed type check is
// reported even when the output is not piped through FileCheck.
int main(void) {
  MlirContext ctx = mlirContextCreate();
  mlirRegisterAllDialects(ctx);
  npcompRegisterAllDialects(ctx);
  // clang-format off
  // CHECK-LABEL: @types
  // CHECK: !basicpy.BoolType
  // CHECK: !numpy.any_dtype
  // CHECK: !numpy.ndarray<[4]:!basicpy.BoolType>
  // CHECK: 0
  // clang-format on
  fprintf(stderr, "@types\n");
  int errcode = printStandardTypes(ctx);
  fprintf(stderr, "%d\n", errcode);
  mlirContextDestroy(ctx);
  // Previously this always returned 0; surface a failed type check as a
  // non-zero exit status as well.
  return errcode == 0 ? EXIT_SUCCESS : EXIT_FAILURE;
}