Make C API files more consistent

- Make consistent with MLIR Core
  - Use `//` or `///` comments.
  - Use `bool` type for booleans
  - No duplicated comments in .cpp files
- Split types into separate files `{Basicpy,Numpy,Torch}Types.h`
- Add dialect prefix consistently to C API symbols. We have lots of
  similarly named types (e.g. "list" type in basicpy and torch).
pull/222/head
Sean Silva 2021-06-14 14:13:59 -07:00
parent db282fd1b4
commit 6b2424512b
17 changed files with 704 additions and 619 deletions

View File

@ -11,7 +11,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/Types.h"
#include "npcomp-c/TorchTypes.h"
#include "npcomp/Python/PybindUtils.h"
#include <ATen/core/function_schema.h>
@ -512,9 +512,8 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
}
if (ival.isList()) {
return npcompListTypeGet(
typeMapper.mapFromTorchType(
loc, ival.toList().elementType()));
return npcompTorchListTypeGet(
typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
}
if (ival.isNone()) {
return npcompTorchNoneTypeGet(funcBuilder->getContext());
@ -530,7 +529,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd(
funcBuilder->getEntryBlock(), "torch.tensor", loc,
npcompNonValueTensorTypeGetFromShaped(
npcompTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);

View File

@ -12,7 +12,7 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "npcomp-c/TorchTypes.h"
using namespace torch_mlir;
@ -132,7 +132,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,
MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
std::vector<MlirValue> &elements) {
MlirType resultType = npcompListTypeGet(elementType);
MlirType resultType = npcompTorchListTypeGet(elementType);
OperationStateHolder state{"torch.prim.ListConstruct", loc};
mlirOperationStateAddResults(state, 1, &resultType);
mlirOperationStateAddOperands(state, elements.size(), elements.data());

View File

@ -16,7 +16,8 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/TorchTypes.h"
#include "caffe2/core/scope_guard.h"
#include "ATen/native/quantized/cpu/packed_params.h"
@ -170,7 +171,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {
MlirOperation nnModule = createMlirOperation(
"torch.nn_module", loc,
npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
mlirRegionCreate());
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
@ -240,7 +241,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
MlirLocation loc = mlirLocationUnknownGet(context);
if (ivalue.isBool()) {
MlirType type = npcompBoolTypeGet(context);
MlirType type = npcompBasicpyBoolTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.bool_constant", loc, type,
toMlirNamedAttribute("value",
@ -269,11 +270,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
for (const c10::IValue &elem : list) {
elems.push_back(importIValue(elem));
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc,
npcompListTypeGet(
typeMapper.mapFromTorchType(
loc, list.elementType())), elems);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.prim.ListConstruct", loc,
npcompTorchListTypeGet(
typeMapper.mapFromTorchType(loc, list.elementType())),
elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTuple()) {
@ -284,7 +285,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
MlirOperation operation =
createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
npcompTupleTypeGet(context), elems);
npcompBasicpyTupleTypeGet(context), elems);
return mlirOperationGetResult(operation, 0);
}
if (ivalue.isTensor()) {
@ -294,7 +295,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
return importModule(ivalue.toModule());
}
if (ivalue.isString()) {
MlirType type = npcompBytesTypeGet(context);
MlirType type = npcompBasicpyBytesTypeGet(context);
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "basicpy.bytes_constant", loc, type,
toMlirNamedAttribute(
@ -325,7 +326,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
}
MlirOperation operation = createMlirOperationAtEnd(
importBlock, "torch.linear_params.create", loc,
npcompLinearParamsTypeGet(context), weightValue, biasValue);
npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
return mlirOperationGetResult(operation, 0);
}
}
@ -345,7 +346,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp =
createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
npcompNonValueTensorTypeGetFromShaped(
npcompTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
@ -360,7 +361,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
// compiler stages that are building a statically modeled quantization
// representation will need to convert this to their representation.
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
MlirType quantizedTensorType = npcompNonValueTensorTypeGet(
MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
context, shape.size(), shape.data(),
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
if (tensor.qscheme() == c10::kPerTensorAffine) {
@ -504,11 +505,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
mlirLocationUnknownGet(context), *maybeDtype);
MlirType typeBound;
if (hasValueSemantics) {
typeBound = npcompValueTensorTypeGet(context, shape.size(),
shape.data(), dtype);
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
shape.data(), dtype);
} else {
typeBound = npcompNonValueTensorTypeGet(context, shape.size(),
shape.data(), dtype);
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
shape.data(), dtype);
}
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(

View File

@ -15,7 +15,8 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/TorchTypes.h"
namespace py = pybind11;
using namespace torch_mlir;
@ -145,7 +146,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
loc, appendToBlock),
mlirRegionCreate());
mapResults(node, operation);
std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
std::vector<MlirType> terminatorOperandTypes = {
npcompBasicpyBoolTypeGet(context)};
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
resultTypes.begin(), resultTypes.end());
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,

View File

@ -10,7 +10,8 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/TorchTypes.h"
using namespace torch_mlir;
@ -23,14 +24,14 @@ MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
return createMlirOperation(
"basicpy.bool_constant", loc, npcompBoolTypeGet(context),
"basicpy.bool_constant", loc, npcompBasicpyBoolTypeGet(context),
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
}
MlirOperation OpBuilder::createBytesConstant(MlirLocation loc,
const std::string &value) {
return createMlirOperation(
"basicpy.bytes_constant", loc, npcompBytesTypeGet(context),
"basicpy.bytes_constant", loc, npcompBasicpyBytesTypeGet(context),
toMlirNamedAttribute("value",
mlirStringAttrGet(context, toMlirStringRef(value))));
}

View File

@ -15,7 +15,8 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/Types.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/TorchTypes.h"
using namespace torch_mlir;
@ -54,7 +55,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
case ScalarType::Long:
return mlirIntegerTypeSignedGet(context, 64);
case ScalarType::Bool:
return npcompBoolTypeGet(context);
return npcompBasicpyBoolTypeGet(context);
case ScalarType::Double:
return mlirF64TypeGet(context);
case ScalarType::Float:
@ -64,7 +65,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
case ScalarType::Half:
return mlirF16TypeGet(context);
case ScalarType::QInt8:
return npcompQInt8TypeGet(context);
return npcompTorchQInt8TypeGet(context);
default: {
return {nullptr};
}
@ -103,7 +104,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
// Individually handle the custom classes that we know about.
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
return npcompLinearParamsTypeGet(context);
return npcompTorchLinearParamsTypeGet(context);
}
// At this point, we know that the type is indeed a custom class type, but
@ -134,11 +135,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
auto &sizes = tensorType->symbolic_sizes();
if (!sizes.rank()) {
// Unranked.
return npcompNonValueTensorTypeGet(context,
/*numSizes=*/0,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
elementType);
return npcompTorchNonValueTensorTypeGet(context,
/*numSizes=*/0,
/*optionalSizes=*/nullptr,
/*optionalDtype=*/
elementType);
}
// Ranked with possibly dynamic dims.
auto &symbolicShape = tensorType->symbolic_sizes();
@ -148,10 +149,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
auto shapeSymbol = symbolicShape[i];
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
}
return npcompNonValueTensorTypeGet(context, dims.size(),
/*optionalSizes=*/dims.data(),
/*optionalDtype=*/
elementType);
return npcompTorchNonValueTensorTypeGet(context, dims.size(),
/*optionalSizes=*/dims.data(),
/*optionalDtype=*/
elementType);
}
case TypeKind::ClassType: {
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
@ -161,15 +162,14 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
}
auto maybeName = classType->name();
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
return npcompNnModuleTypeGet(context, toMlirStringRef(name));
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
}
case TypeKind::FloatType: {
return mlirF64TypeGet(context);
}
case TypeKind::OptionalType: {
return npcompOptionalTypeGet(
mapFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
return npcompTorchOptionalTypeGet(mapFromTorchType(
loc, torchType->cast<c10::OptionalType>()->getElementType()));
}
case TypeKind::IntType: {
return mlirIntegerTypeGet(context, 64);
@ -178,22 +178,21 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
return npcompTorchNoneTypeGet(context);
}
case TypeKind::BoolType: {
return npcompBoolTypeGet(context);
return npcompBasicpyBoolTypeGet(context);
}
case TypeKind::ListType: {
return npcompListTypeGet(
mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
return npcompTorchListTypeGet(mapFromTorchType(
loc, torchType->cast<c10::ListType>()->getElementType()));
}
case TypeKind::TupleType: {
// TODO: Don't lose the element type information.
return npcompTupleTypeGet(context);
return npcompBasicpyTupleTypeGet(context);
}
case TypeKind::StringType: {
return npcompBytesTypeGet(context);
return npcompBasicpyBytesTypeGet(context);
}
case TypeKind::DeviceObjType: {
return npcompDeviceTypeGet(context);
return npcompTorchDeviceTypeGet(context);
}
default: {
std::stringstream message;
@ -216,8 +215,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
// just erase them and let the compiler decide.
auto sizes = tensor.sizes();
return npcompNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
elementType);
return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
elementType);
}
MlirType

View File

@ -0,0 +1,92 @@
//===-- npcomp-c/BasicpyTypes.h - C API for basicpy types ---------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_C_BASICPYTYPES_H
#define NPCOMP_C_BASICPYTYPES_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
//===----------------------------------------------------------------------===//
// !basicpy.BoolType
//===----------------------------------------------------------------------===//
/// Checks whether the given type is the Python "bool" type.
bool npcompTypeIsABasicpyBool(MlirType t);
/// Gets the Python "bool" type.
MlirType npcompBasicpyBoolTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !basicpy.BytesType
//===----------------------------------------------------------------------===//
/// Checks whether the given type is the Python "bytes" type.
bool npcompTypeIsABasicpyBytes(MlirType t);
/// Gets the Python "bytes" type.
MlirType npcompBasicpyBytesTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !basicpy.DictType
//===----------------------------------------------------------------------===//
/// Checks whether the given type is the Python "dict" type.
bool npcompTypeIsABasicpyDict(MlirType t);
/// Gets the generic Python "dict" type.
MlirType npcompBasicpyDictTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// List type
//===----------------------------------------------------------------------===//
/// Checks whether the given type is the Python "list" type.
bool npcompTypeIsABasicpyList(MlirType t);
/// Gets the generic Python "list" type.
MlirType npcompBasicpyListTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// !basicpy.NoneType type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a `!basicpy.NoneType`.
bool npcompTypeIsABasicpyNone(MlirType t);
/// Gets the `!basicpy.NoneType` type.
MlirType npcompBasicpyNoneTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// SlotObject type.
//===----------------------------------------------------------------------===//
MlirType npcompBasicPySlotObjectTypeGet(MlirContext context,
MlirStringRef className,
intptr_t slotTypeCount,
const MlirType *slotTypes);
//===----------------------------------------------------------------------===//
// !basicpy.TupleType
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a `!basicpy.TupleType`.
bool npcompTypeIsABasicpyTuple(MlirType t);
/// Gets the generic Python "tuple" type.
MlirType npcompBasicpyTupleTypeGet(MlirContext context);
#ifdef __cplusplus
}
#endif
#endif // NPCOMP_C_BASICPYTYPES_H

View File

@ -0,0 +1,55 @@
//===-- npcomp-c/NumpyTypes.h - C API for numpy types -------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_C_NUMPYTYPES_H
#define NPCOMP_C_NUMPYTYPES_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
//===----------------------------------------------------------------------===//
// !numpy.any_dtype type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is the special "any dtype" type that is used
// to signal an NDArray or tensor of unknown type.
bool npcompTypeIsANumpyAnyDtype(MlirType t);
/// Gets the "any dtype" type.
MlirType npcompAnyDtypeTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// NDArray type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is an NdArray type.
bool npcompTypeIsANumpyNdArray(MlirType t);
/// Gets a numpy.NdArray type that is unranked.
MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType);
/// Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
/// unknown.
MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType);
/// Helper that gets an equivalent NdArrayType from a ShapedType.
MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType);
/// Helper that converts an NdArrayType to a TensorType.
MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType);
#ifdef __cplusplus
}
#endif
#endif // NPCOMP_C_NUMPYTYPES_H

View File

@ -0,0 +1,143 @@
//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_C_TORCHTYPES_H
#define NPCOMP_C_TORCHTYPES_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
//===----------------------------------------------------------------------===//
// torch.nn.Module type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a torch.nn.Module type
bool npcompTypeIsATorchNnModule(MlirType t);
/// Gets the !torch.nn.Module type of the specified class.
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
MlirStringRef className);
//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.optional<T> type
bool npcompTypeIsATorchOptional(MlirType t);
/// Gets the !torch.optional<T> type with subtype T.
MlirType npcompTorchOptionalTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.list<T> type
bool npcompTypeIsATorchList(MlirType t);
/// Gets the !torch.list<T> type with contained T.
MlirType npcompTorchListTypeGet(MlirType containedType);
//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.Device type
bool npcompTypeIsATorchDevice(MlirType t);
/// Gets the !torch.Device type.
MlirType npcompTorchDeviceTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.LinearParams type
bool npcompTypeIsATorchLinearParams(MlirType t);
/// Gets the !torch.LinearParams type.
MlirType npcompTorchLinearParamsTypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.qint8 type
bool npcompTypeIsATorchQInt8(MlirType t);
/// Gets the !torch.qint8 type.
MlirType npcompTorchQInt8TypeGet(MlirContext context);
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.tensor type
bool npcompTypeIsATorchNonValueTensor(MlirType t);
/// Gets a !torch.tensor type.
///
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case). -
/// `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present.
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype);
/// Gets the !torch.tensor type with the least static information.
MlirType
npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.vtensor type
bool npcompTypeIsATorchValueTensor(MlirType t);
/// Gets a !torch.vtensor type.
///
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present.
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype);
/// Gets the !torch.tensor type with the least static information.
MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//
/// Checks whether the given type is a !torch.none type
bool npcompTypeIsATorchNone(MlirType t);
/// Gets the !torch.none type.
MlirType npcompTorchNoneTypeGet(MlirContext context);
#ifdef __cplusplus
}
#endif
#endif // NPCOMP_C_TORCHTYPES_H

View File

@ -1,246 +0,0 @@
/*===-- npcomp-c/Types.h - NPComp custom types --------------------*- C -*-===*\
|* *|
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|* Exceptions. *|
|* See https://llvm.org/LICENSE.txt for license information. *|
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|* *|
\*===----------------------------------------------------------------------===*/
#ifndef NPCOMP_C_TYPES_H
#define NPCOMP_C_TYPES_H
#include "mlir-c/IR.h"
#ifdef __cplusplus
extern "C" {
#endif
/*============================================================================*/
/* Any dtype type. */
/*============================================================================*/
/** Checks whether the given type is the special "any dtype" type that is used
* to signal an NDArray or tensor of unknown type. */
int npcompTypeIsAAnyDtype(MlirType t);
/** Gets the "any dtype" type. */
MlirType npcompAnyDtypeTypeGet(MlirContext context);
/*============================================================================*/
/* Bool type. */
/*============================================================================*/
/** Checks whether the given type is the Python "bool" type. */
int npcompTypeIsABool(MlirType t);
/** Gets the Python "bool" type. */
MlirType npcompBoolTypeGet(MlirContext context);
/*============================================================================*/
/* Bytes type. */
/*============================================================================*/
/** Checks whether the given type is the Python "bytes" type. */
int npcompTypeIsABytes(MlirType t);
/** Gets the Python "bytes" type. */
MlirType npcompBytesTypeGet(MlirContext context);
/*============================================================================*/
/* Dict type. */
/*============================================================================*/
/** Checks whether the given type is the Python "dict" type. */
int npcompTypeIsADict(MlirType t);
/** Gets the generic Python "dict" type. */
MlirType npcompDictTypeGet(MlirContext context);
/*============================================================================*/
/* List type. */
/*============================================================================*/
/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsABasicpyList(MlirType t);
/** Gets the generic Python "list" type. */
MlirType npcompBasicpyListTypeGet(MlirContext context);
/*============================================================================*/
/* NDArray type. */
/*============================================================================*/
/** Checks whether the given type is an NdArray type. */
int npcompTypeIsANdArray(MlirType t);
/** Gets a numpy.NdArray type that is unranked. */
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType);
/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
* unknown. */
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType);
/// Helper that gets an equivalent NdArrayType from a ShapedType.
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
/// Helper that converts an NdArrayType to a TensorType.
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType);
/*============================================================================*/
/* !basicpy.NoneType type. */
/*============================================================================*/
/** Checks whether the given type is a `!basicpy.NoneType`. */
int npcompTypeIsABasicpyNone(MlirType t);
/** Gets the `!basicpy.NoneType` type. */
MlirType npcompBasicpyNoneTypeGet(MlirContext context);
/*============================================================================*/
/* SlotObject type. */
/*============================================================================*/
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
intptr_t slotTypeCount,
const MlirType *slotTypes);
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/
/** Checks whether the given type is the special "any dtype" type that is used
* to signal an NDArray or tensor of unknown type. */
int npcompTypeIsATuple(MlirType t);
/** Gets the generic Python "tuple" type. */
MlirType npcompTupleTypeGet(MlirContext context);
/*============================================================================*/
/* torch.nn.Module type. */
/*============================================================================*/
/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t);
/** Gets the !torch.nn.Module type of the specified class. */
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className);
/*============================================================================*/
/* torch.optional type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.optional<T> type */
int npcompTypeIsAOptional(MlirType t);
/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompOptionalTypeGet(MlirType containedType);
/*============================================================================*/
/* torch.list type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.list<T> type */
int npcompTypeIsAList(MlirType t);
/** Gets the !torch.list<T> type with contained T. */
MlirType npcompListTypeGet(MlirType containedType);
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.Device type */
int npcompTypeIsADevice(MlirType t);
/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context);
/*============================================================================*/
/* torch.LinearParams type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.LinearParams type */
int npcompTypeIsALinearParams(MlirType t);
/** Gets the !torch.LinearParams type. */
MlirType npcompLinearParamsTypeGet(MlirContext context);
/*============================================================================*/
/* torch.qint8 type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.qint8 type */
int npcompTypeIsAQInt8(MlirType t);
/** Gets the !torch.qint8 type. */
MlirType npcompQInt8TypeGet(MlirContext context);
/*============================================================================*/
/* torch.tensor type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.tensor type */
int npcompTypeIsANonValueTensor(MlirType t);
/** Gets a !torch.tensor type.
*
* - `optionalSizes` is allowed to be null, meaning that no size information is
* present (and `numSizes` is ignored in that case).
* - `optionalDtype` is allowed to be null, meaning that no dtype information is
* present.
*
*/
MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype);
/** Gets the !torch.tensor type with the least static information. */
MlirType
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type);
/*============================================================================*/
/* torch.vtensor type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.vtensor type */
int npcompTypeIsAValueTensor(MlirType t);
/** Gets a !torch.vtensor type.
*
* - `optionalSizes` is allowed to be null, meaning that no size information is
* present (and `numSizes` is ignored in that case).
* - `optionalDtype` is allowed to be null, meaning that no dtype information is
* present.
*
*/
MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype);
/** Gets the !torch.tensor type with the least static information. */
MlirType
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
MlirType npcompValueTensorTypeGetFromShaped(MlirType type);
/*============================================================================*/
/* !torch.none type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.none type */
int npcompTypeIsATorchNone(MlirType t);
/** Gets the !torch.none type. */
MlirType npcompTorchNoneTypeGet(MlirContext context);
#ifdef __cplusplus
}
#endif
#endif // NPCOMP_C_TYPES_H

View File

@ -0,0 +1,116 @@
//===- BasicpyTypes.cpp - C Interface for basicpy types -------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp-c/BasicpyTypes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
// Bool type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsABasicpyBool(MlirType t) {
return unwrap(t).isa<Basicpy::BoolType>();
}
MlirType npcompBasicpyBoolTypeGet(MlirContext context) {
return wrap(Basicpy::BoolType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// Bytes type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsABasicpyBytes(MlirType t) {
return unwrap(t).isa<Basicpy::BytesType>();
}
MlirType npcompBasicpyBytesTypeGet(MlirContext context) {
return wrap(Basicpy::BytesType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// Dict type.
//===----------------------------------------------------------------------===//
/** Checks whether the given type is the Python "dict" type. */
bool npcompTypeIsABasicpyDict(MlirType t) {
return unwrap(t).isa<Basicpy::DictType>();
}
/** Gets the generic Python "dict" type. */
MlirType npcompBasicpyDictTypeGet(MlirContext context) {
return wrap(Basicpy::DictType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// List type.
//===----------------------------------------------------------------------===//
/** Checks whether the given type is the Python "list" type. */
bool npcompTypeIsABasicpyList(MlirType t) {
return unwrap(t).isa<Basicpy::ListType>();
}
/** Gets the generic Python "dict" type. */
MlirType npcompBasicpyListTypeGet(MlirContext context) {
return wrap(Basicpy::ListType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// !basicpy.NoneType type.
//===----------------------------------------------------------------------===//
/** Checks whether the given type is a `!basicpy.NoneType`. */
bool npcompTypeIsANone(MlirType t) {
return unwrap(t).isa<Basicpy::NoneType>();
}
/** Gets the `!basicpy.NoneType` type. */
MlirType npcompBasicpyNoneTypeGet(MlirContext context) {
return wrap(Basicpy::NoneType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// SlotObject type.
//===----------------------------------------------------------------------===//
MlirType npcompBasicPySlotObjectTypeGet(MlirContext context,
MlirStringRef className,
intptr_t slotTypeCount,
const MlirType *slotTypes) {
MLIRContext *cppContext = unwrap(context);
auto classNameAttr = StringAttr::get(cppContext, unwrap(className));
SmallVector<Type> slotTypesCpp;
slotTypesCpp.resize(slotTypeCount);
for (intptr_t i = 0; i < slotTypeCount; ++i) {
slotTypesCpp[i] = unwrap(slotTypes[i]);
}
return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
}
//===----------------------------------------------------------------------===//
// Tuple type.
//===----------------------------------------------------------------------===//
/** Checks whether the given type is the special "any dtype" type that is used
* to signal an NDArray or tensor of unknown type. */
bool npcompTypeIsABasicpyTuple(MlirType t) {
return unwrap(t).isa<Basicpy::TupleType>();
}
/** Gets the "any dtype" type. */
MlirType npcompBasicpyTupleTypeGet(MlirContext context) {
return wrap(Basicpy::TupleType::get(unwrap(context)));
}

View File

@ -7,7 +7,9 @@ set(LLVM_LINK_COMPONENTS
add_npcomp_library(NPCOMPCAPI
InitLLVM.cpp
Registration.cpp
Types.cpp
BasicpyTypes.cpp
NumpyTypes.cpp
TorchTypes.cpp
LINK_LIBS PUBLIC
MLIRExecutionEngine

View File

@ -0,0 +1,56 @@
//===- NumpyTypes.cpp - C Interface for numpy types -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp-c/NumpyTypes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
// !numpy.any_dtype type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsANumpyAnyDtype(MlirType t) {
return unwrap(t).isa<Numpy::AnyDtypeType>();
}
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// NDArray type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsANumpyNdArray(MlirType t) {
return unwrap(t).isa<Numpy::NdArrayType>();
}
MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType) {
return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
}
MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType) {
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
}
MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType) {
return wrap(Numpy::NdArrayType::getFromShapedType(
unwrap(shapedType).cast<ShapedType>()));
}
MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType) {
return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
}

View File

@ -0,0 +1,165 @@
//===- TorchTypes.cpp - C Interface for torch types -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp-c/TorchTypes.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
// torch.nn.Module type.
//===----------------------------------------------------------------------===//
/** Checks whether the given type is a torch.nn.Module type */
bool npcompTypeIsATorchNnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>();
}
/** Gets the torch.nn.Module type of the specified class. */
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
MlirStringRef className) {
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}
//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//
/** Checks whether the given type is a !torch.optional<T> type */
bool npcompTypeIsATorchOptional(MlirType t) {
return unwrap(t).isa<Torch::OptionalType>();
}
/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType)));
}
//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchList(MlirType t) {
return unwrap(t).isa<Torch::ListType>();
}
MlirType npcompTorchListTypeGet(MlirType containedType) {
return wrap(Torch::ListType::get(unwrap(containedType)));
}
//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchDevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>();
}
MlirType npcompTorchDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchLinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>();
}
MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>();
}
MlirType npcompTorchQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context)));
}
//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNonValueTensor(MlirType t) {
return unwrap(t).isa<Torch::NonValueTensorType>();
}
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirContext context) {
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
unwrap(context)));
}
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
return wrap(Torch::NonValueTensorType::getFromShaped(
unwrap(type).cast<ShapedType>()));
}
//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchValueTensor(MlirType t) {
return unwrap(t).isa<Torch::ValueTensorType>();
}
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
return wrap(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
return wrap(
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}
//===----------------------------------------------------------------------===//
// torch.none type.
//===----------------------------------------------------------------------===//
bool npcompTypeIsATorchNone(MlirType t) {
return unwrap(t).isa<Torch::NoneType>();
}
MlirType npcompTorchNoneTypeGet(MlirContext context) {
return wrap(Torch::NoneType::get(unwrap(context)));
}

View File

@ -1,303 +0,0 @@
//===- Types.cpp - C Interface for NPComp types ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp-c/Types.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
using namespace mlir;
using namespace mlir::NPCOMP;
/*============================================================================*/
/* Any dtype type. */
/*============================================================================*/
int npcompTypeIsAAnyDtype(MlirType t) {
return unwrap(t).isa<Numpy::AnyDtypeType>();
}
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
}
/*============================================================================*/
/* Bool type. */
/*============================================================================*/
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<Basicpy::BoolType>(); }
MlirType npcompBoolTypeGet(MlirContext context) {
return wrap(Basicpy::BoolType::get(unwrap(context)));
}
/*============================================================================*/
/* Bytes type. */
/*============================================================================*/
int npcompTypeIsABytes(MlirType t) {
return unwrap(t).isa<Basicpy::BytesType>();
}
MlirType npcompBytesTypeGet(MlirContext context) {
return wrap(Basicpy::BytesType::get(unwrap(context)));
}
/*============================================================================*/
/* Dict type. */
/*============================================================================*/
/** Checks whether the given type is the Python "dict" type. */
int npcompTypeIsADict(MlirType t) { return unwrap(t).isa<Basicpy::DictType>(); }
/** Gets the generic Python "dict" type. */
MlirType npcompDictTypeGet(MlirContext context) {
return wrap(Basicpy::DictType::get(unwrap(context)));
}
/*============================================================================*/
/* List type. */
/*============================================================================*/
/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
/** Gets the generic Python "dict" type. */
MlirType npcompBasicpyListTypeGet(MlirContext context) {
return wrap(Basicpy::ListType::get(unwrap(context)));
}
/*============================================================================*/
/* NDArray type. */
/*============================================================================*/
int npcompTypeIsANdArray(MlirType t) {
return unwrap(t).isa<Numpy::NdArrayType>();
}
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) {
return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
}
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
MlirType elementType) {
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
}
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
return wrap(Numpy::NdArrayType::getFromShapedType(
unwrap(shapedType).cast<ShapedType>()));
}
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType) {
return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
}
/*============================================================================*/
/* !basicpy.NoneType type. */
/*============================================================================*/
/** Checks whether the given type is a `!basicpy.NoneType`. */
int npcompTypeIsANone(MlirType t) { return unwrap(t).isa<Basicpy::NoneType>(); }
/** Gets the `!basicpy.NoneType` type. */
MlirType npcompBasicpyNoneTypeGet(MlirContext context) {
return wrap(Basicpy::NoneType::get(unwrap(context)));
}
/*============================================================================*/
/* SlotObject type. */
/*============================================================================*/
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
intptr_t slotTypeCount,
const MlirType *slotTypes) {
MLIRContext *cppContext = unwrap(context);
auto classNameAttr = StringAttr::get(cppContext, unwrap(className));
SmallVector<Type> slotTypesCpp;
slotTypesCpp.resize(slotTypeCount);
for (intptr_t i = 0; i < slotTypeCount; ++i) {
slotTypesCpp[i] = unwrap(slotTypes[i]);
}
return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
}
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/
/** Checks whether the given type is the special "any dtype" type that is used
* to signal an NDArray or tensor of unknown type. */
int npcompTypeIsATuple(MlirType t) {
return unwrap(t).isa<Basicpy::TupleType>();
}
/** Gets the "any dtype" type. */
MlirType npcompTupleTypeGet(MlirContext context) {
return wrap(Basicpy::TupleType::get(unwrap(context)));
}
/*============================================================================*/
/* torch.nn.Module type. */
/*============================================================================*/
/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t) {
return unwrap(t).isa<Torch::NnModuleType>();
}
/** Gets the torch.nn.Module type of the specified class. */
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) {
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}
/*============================================================================*/
/* torch.optional type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.optional<T> type */
int npcompTypeIsAOptional(MlirType t) {
return unwrap(t).isa<Torch::OptionalType>();
}
/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompOptionalTypeGet(MlirType containedType) {
return wrap(Torch::OptionalType::get(unwrap(containedType)));
}
/*============================================================================*/
/* torch.list type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.list<T> type */
int npcompTypeIsAList(MlirType t) {
return unwrap(t).isa<Torch::ListType>();
}
/** Gets the !torch.List<T> type with contained T. */
MlirType npcompListTypeGet(MlirType containedType) {
return wrap(Torch::ListType::get(unwrap(containedType)));
}
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.Device type */
int npcompTypeIsADevice(MlirType t) {
return unwrap(t).isa<Torch::DeviceType>();
}
/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context) {
return wrap(Torch::DeviceType::get(unwrap(context)));
}
/*============================================================================*/
/* torch.LinearParams type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.LinearParams type */
int npcompTypeIsALinearParams(MlirType t) {
return unwrap(t).isa<Torch::LinearParamsType>();
}
/** Gets the !torch.LinearParams type. */
MlirType npcompLinearParamsTypeGet(MlirContext context) {
return wrap(Torch::LinearParamsType::get(unwrap(context)));
}
/*============================================================================*/
/* torch.qint8 type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.qint8 type */
int npcompTypeIsAQInt8(MlirType t) {
return unwrap(t).isa<Torch::QInt8Type>();
}
/** Gets the !torch.qint8 type. */
MlirType npcompQInt8TypeGet(MlirContext context) {
return wrap(Torch::QInt8Type::get(unwrap(context)));
}
/*============================================================================*/
/* torch.tensor type. */
/*============================================================================*/
int npcompTypeIsANonValueTensor(MlirType t) {
return unwrap(t).isa<Torch::NonValueTensorType>();
}
MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::NonValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
MlirType
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
unwrap(context)));
}
MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type) {
return wrap(Torch::NonValueTensorType::getFromShaped(
unwrap(type).cast<ShapedType>()));
}
/*============================================================================*/
/* torch.vtensor type. */
/*============================================================================*/
int npcompTypeIsAValueTensor(MlirType t) {
return unwrap(t).isa<Torch::ValueTensorType>();
}
MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
const int64_t *optionalSizes,
MlirType optionalDtype) {
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
if (optionalSizes)
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
return wrap(Torch::ValueTensorType::get(
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}
MlirType
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
return wrap(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}
MlirType npcompValueTensorTypeGetFromShaped(MlirType type) {
return wrap(
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}
/*============================================================================*/
/* torch.none type. */
/*============================================================================*/
/** Checks whether the given type is a !torch.none type */
int npcompTypeIsATorchNone(MlirType t) {
return unwrap(t).isa<Torch::NoneType>();
}
/** Gets the !torch.none type. */
MlirType npcompTorchNoneTypeGet(MlirContext context) {
return wrap(Torch::NoneType::get(unwrap(context)));
}

View File

@ -12,9 +12,10 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp-c/NumpyTypes.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"
#include "npcomp/Python/PybindUtils.h"
#ifdef NPCOMP_ENABLE_REFJIT
@ -27,21 +28,21 @@ MlirType shapedToNdArrayArrayType(MlirType shaped_type) {
if (!mlirTypeIsAShaped(shaped_type)) {
throw py::raiseValueError("type is not a shaped type");
}
return npcompNdArrayTypeGetFromShaped(shaped_type);
return npcompNumpyNdArrayTypeGetFromShaped(shaped_type);
}
MlirType ndarrayToTensorType(MlirType ndarray_type) {
if (!npcompTypeIsANdArray(ndarray_type)) {
if (!npcompTypeIsANumpyNdArray(ndarray_type)) {
throw py::raiseValueError("type is not an ndarray type");
}
return npcompNdArrayTypeToTensor(ndarray_type);
return npcompNumpyNdArrayTypeToTensor(ndarray_type);
}
MlirType slotObjectType(MlirContext context, const std::string &className,
const std::vector<MlirType> &slotTypes) {
MlirStringRef classNameSr{className.data(), className.size()};
return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(),
slotTypes.data());
return ::npcompBasicPySlotObjectTypeGet(context, classNameSr,
slotTypes.size(), slotTypes.data());
}
// TODO: Move this upstream.

View File

@ -12,8 +12,9 @@
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/NumpyTypes.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"
#include <assert.h>
#include <math.h>
@ -24,30 +25,31 @@
// Dumps an instance of all NPComp types.
static int printStandardTypes(MlirContext ctx) {
// Bool type.
MlirType boolType = npcompBoolTypeGet(ctx);
if (!npcompTypeIsABool(boolType))
MlirType boolType = npcompBasicpyBoolTypeGet(ctx);
if (!npcompTypeIsABasicpyBool(boolType))
return 1;
mlirTypeDump(boolType);
fprintf(stderr, "\n");
// Bytes type.
MlirType bytesType = npcompBytesTypeGet(ctx);
if (!npcompTypeIsABytes(bytesType))
MlirType bytesType = npcompBasicpyBytesTypeGet(ctx);
if (!npcompTypeIsABasicpyBytes(bytesType))
return 1;
mlirTypeDump(bytesType);
fprintf(stderr, "\n");
// Any dtype.
MlirType anyDtype = npcompAnyDtypeTypeGet(ctx);
if (!npcompTypeIsAAnyDtype(anyDtype))
if (!npcompTypeIsANumpyAnyDtype(anyDtype))
return 2;
mlirTypeDump(anyDtype);
fprintf(stderr, "\n");
// Ranked NdArray.
int64_t fourDim = 4;
MlirType rankedNdArray = npcompNdArrayTypeGetRanked(1, &fourDim, boolType);
if (!npcompTypeIsANdArray(rankedNdArray))
MlirType rankedNdArray =
npcompNumpyNdArrayTypeGetRanked(1, &fourDim, boolType);
if (!npcompTypeIsANumpyNdArray(rankedNdArray))
return 3;
mlirTypeDump(rankedNdArray);
fprintf(stderr, "\n");