Mirror of https://github.com/llvm/torch-mlir

Make C API files more consistent (pull/222/head)

- Make consistent with MLIR Core
- Use `//` or `///` comments.
- Use `bool` type for booleans
- No duplicated comments in .cpp files
- Split types into separate files `{Basicpy,Numpy,Torch}Types.h`
- Add dialect prefix consistently to C API symbols. We have lots of
  similarly named types (e.g. "list" type in basicpy and torch).

parent db282fd1b4
commit 6b2424512b
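For orientation, a minimal sketch of what the renaming looks like at a call site; the old and new C API symbol names below are taken directly from this diff, while the wrapper function and variable names are illustrative only:

#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/TorchTypes.h"

// Illustrative only: contrasts the old, dialect-ambiguous spellings with the
// new prefixed ones. `context` and `elementType` come from the caller.
static void exampleOfRenamedSymbols(MlirContext context, MlirType elementType) {
  // Old (removed in this commit, previously declared in npcomp-c/Types.h):
  //   MlirType listType = npcompListTypeGet(elementType);
  //   MlirType boolType = npcompBoolTypeGet(context);
  // New: the owning dialect is part of the symbol name.
  MlirType torchListType = npcompTorchListTypeGet(elementType);
  MlirType basicpyBoolType = npcompBasicpyBoolTypeGet(context);
  (void)torchListType;
  (void)basicpyBoolType;
}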
@@ -11,7 +11,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/TorchTypes.h"
 #include "npcomp/Python/PybindUtils.h"

 #include <ATen/core/function_schema.h>

@@ -512,9 +512,8 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
     return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
   }
   if (ival.isList()) {
-    return npcompListTypeGet(
-        typeMapper.mapFromTorchType(
-            loc, ival.toList().elementType()));
+    return npcompTorchListTypeGet(
+        typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
   }
   if (ival.isNone()) {
     return npcompTorchNoneTypeGet(funcBuilder->getContext());

@@ -530,7 +529,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
   MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
   MlirOperation tensorOp = createMlirOperationAtEnd(
       funcBuilder->getEntryBlock(), "torch.tensor", loc,
-      npcompNonValueTensorTypeGetFromShaped(
+      npcompTorchNonValueTensorTypeGetFromShaped(
           mlirAttributeGetType(denseElements)),
       toMlirNamedAttribute("value", denseElements));
   MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
@@ -12,7 +12,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/TorchTypes.h"

 using namespace torch_mlir;

@@ -132,7 +132,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,

 MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
                                  std::vector<MlirValue> &elements) {
-  MlirType resultType = npcompListTypeGet(elementType);
+  MlirType resultType = npcompTorchListTypeGet(elementType);
   OperationStateHolder state{"torch.prim.ListConstruct", loc};
   mlirOperationStateAddResults(state, 1, &resultType);
   mlirOperationStateAddOperands(state, elements.size(), elements.data());
@ -16,7 +16,8 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
#include "npcomp-c/BasicpyTypes.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
|
||||
#include "caffe2/core/scope_guard.h"
|
||||
#include "ATen/native/quantized/cpu/packed_params.h"
|
||||
|
@ -170,7 +171,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {
|
|||
|
||||
MlirOperation nnModule = createMlirOperation(
|
||||
"torch.nn_module", loc,
|
||||
npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||
npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
|
||||
mlirRegionCreate());
|
||||
MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
|
||||
mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));
|
||||
|
@ -240,7 +241,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
MlirLocation loc = mlirLocationUnknownGet(context);
|
||||
|
||||
if (ivalue.isBool()) {
|
||||
MlirType type = npcompBoolTypeGet(context);
|
||||
MlirType type = npcompBasicpyBoolTypeGet(context);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "basicpy.bool_constant", loc, type,
|
||||
toMlirNamedAttribute("value",
|
||||
|
@ -269,11 +270,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
for (const c10::IValue &elem : list) {
|
||||
elems.push_back(importIValue(elem));
|
||||
}
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc,
|
||||
npcompListTypeGet(
|
||||
typeMapper.mapFromTorchType(
|
||||
loc, list.elementType())), elems);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.prim.ListConstruct", loc,
|
||||
npcompTorchListTypeGet(
|
||||
typeMapper.mapFromTorchType(loc, list.elementType())),
|
||||
elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTuple()) {
|
||||
|
@ -284,7 +285,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
MlirOperation operation =
|
||||
createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
|
||||
npcompTupleTypeGet(context), elems);
|
||||
npcompBasicpyTupleTypeGet(context), elems);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
if (ivalue.isTensor()) {
|
||||
|
@ -294,7 +295,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
return importModule(ivalue.toModule());
|
||||
}
|
||||
if (ivalue.isString()) {
|
||||
MlirType type = npcompBytesTypeGet(context);
|
||||
MlirType type = npcompBasicpyBytesTypeGet(context);
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "basicpy.bytes_constant", loc, type,
|
||||
toMlirNamedAttribute(
|
||||
|
@ -325,7 +326,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
|
|||
}
|
||||
MlirOperation operation = createMlirOperationAtEnd(
|
||||
importBlock, "torch.linear_params.create", loc,
|
||||
npcompLinearParamsTypeGet(context), weightValue, biasValue);
|
||||
npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
|
||||
return mlirOperationGetResult(operation, 0);
|
||||
}
|
||||
}
|
||||
|
@ -345,7 +346,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp =
|
||||
createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
|
||||
npcompNonValueTensorTypeGetFromShaped(
|
||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
|
||||
|
@ -360,7 +361,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
// compiler stages that are building a statically modeled quantization
|
||||
// representation will need to convert this to their representation.
|
||||
std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
|
||||
MlirType quantizedTensorType = npcompNonValueTensorTypeGet(
|
||||
MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
|
||||
context, shape.size(), shape.data(),
|
||||
typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
|
||||
if (tensor.qscheme() == c10::kPerTensorAffine) {
|
||||
|
@ -504,11 +505,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
|
|||
mlirLocationUnknownGet(context), *maybeDtype);
|
||||
MlirType typeBound;
|
||||
if (hasValueSemantics) {
|
||||
typeBound = npcompValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
} else {
|
||||
typeBound = npcompNonValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
|
||||
shape.data(), dtype);
|
||||
}
|
||||
|
||||
MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
#include "npcomp-c/BasicpyTypes.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace torch_mlir;
|
||||
|
@ -145,7 +146,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
|
|||
loc, appendToBlock),
|
||||
mlirRegionCreate());
|
||||
mapResults(node, operation);
|
||||
std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
|
||||
std::vector<MlirType> terminatorOperandTypes = {
|
||||
npcompBasicpyBoolTypeGet(context)};
|
||||
terminatorOperandTypes.insert(terminatorOperandTypes.end(),
|
||||
resultTypes.begin(), resultTypes.end());
|
||||
auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
|
||||
|
|
|
@ -10,7 +10,8 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
#include "npcomp-c/BasicpyTypes.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -23,14 +24,14 @@ MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {
|
|||
|
||||
MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
|
||||
return createMlirOperation(
|
||||
"basicpy.bool_constant", loc, npcompBoolTypeGet(context),
|
||||
"basicpy.bool_constant", loc, npcompBasicpyBoolTypeGet(context),
|
||||
toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
|
||||
}
|
||||
|
||||
MlirOperation OpBuilder::createBytesConstant(MlirLocation loc,
|
||||
const std::string &value) {
|
||||
return createMlirOperation(
|
||||
"basicpy.bytes_constant", loc, npcompBytesTypeGet(context),
|
||||
"basicpy.bytes_constant", loc, npcompBasicpyBytesTypeGet(context),
|
||||
toMlirNamedAttribute("value",
|
||||
mlirStringAttrGet(context, toMlirStringRef(value))));
|
||||
}
|
||||
|
|
|
@ -15,7 +15,8 @@
|
|||
#include "mlir-c/BuiltinAttributes.h"
|
||||
#include "mlir-c/BuiltinTypes.h"
|
||||
#include "mlir-c/Diagnostics.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
#include "npcomp-c/BasicpyTypes.h"
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
|
@ -54,7 +55,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
|||
case ScalarType::Long:
|
||||
return mlirIntegerTypeSignedGet(context, 64);
|
||||
case ScalarType::Bool:
|
||||
return npcompBoolTypeGet(context);
|
||||
return npcompBasicpyBoolTypeGet(context);
|
||||
case ScalarType::Double:
|
||||
return mlirF64TypeGet(context);
|
||||
case ScalarType::Float:
|
||||
|
@ -64,7 +65,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
|
|||
case ScalarType::Half:
|
||||
return mlirF16TypeGet(context);
|
||||
case ScalarType::QInt8:
|
||||
return npcompQInt8TypeGet(context);
|
||||
return npcompTorchQInt8TypeGet(context);
|
||||
default: {
|
||||
return {nullptr};
|
||||
}
|
||||
|
@ -103,7 +104,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,
|
|||
|
||||
// Individually handle the custom classes that we know about.
|
||||
if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
|
||||
return npcompLinearParamsTypeGet(context);
|
||||
return npcompTorchLinearParamsTypeGet(context);
|
||||
}
|
||||
|
||||
// At this point, we know that the type is indeed a custom class type, but
|
||||
|
@ -134,11 +135,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
auto &sizes = tensorType->symbolic_sizes();
|
||||
if (!sizes.rank()) {
|
||||
// Unranked.
|
||||
return npcompNonValueTensorTypeGet(context,
|
||||
/*numSizes=*/0,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
return npcompTorchNonValueTensorTypeGet(context,
|
||||
/*numSizes=*/0,
|
||||
/*optionalSizes=*/nullptr,
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
// Ranked with possibly dynamic dims.
|
||||
auto &symbolicShape = tensorType->symbolic_sizes();
|
||||
|
@ -148,10 +149,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
auto shapeSymbol = symbolicShape[i];
|
||||
dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
|
||||
}
|
||||
return npcompNonValueTensorTypeGet(context, dims.size(),
|
||||
/*optionalSizes=*/dims.data(),
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
return npcompTorchNonValueTensorTypeGet(context, dims.size(),
|
||||
/*optionalSizes=*/dims.data(),
|
||||
/*optionalDtype=*/
|
||||
elementType);
|
||||
}
|
||||
case TypeKind::ClassType: {
|
||||
const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();
|
||||
|
@ -161,15 +162,14 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
}
|
||||
auto maybeName = classType->name();
|
||||
std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
|
||||
return npcompNnModuleTypeGet(context, toMlirStringRef(name));
|
||||
return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
|
||||
}
|
||||
case TypeKind::FloatType: {
|
||||
return mlirF64TypeGet(context);
|
||||
}
|
||||
case TypeKind::OptionalType: {
|
||||
return npcompOptionalTypeGet(
|
||||
mapFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
return npcompTorchOptionalTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::OptionalType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::IntType: {
|
||||
return mlirIntegerTypeGet(context, 64);
|
||||
|
@ -178,22 +178,21 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
|
|||
return npcompTorchNoneTypeGet(context);
|
||||
}
|
||||
case TypeKind::BoolType: {
|
||||
return npcompBoolTypeGet(context);
|
||||
return npcompBasicpyBoolTypeGet(context);
|
||||
}
|
||||
case TypeKind::ListType: {
|
||||
return npcompListTypeGet(
|
||||
mapFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
return npcompTorchListTypeGet(mapFromTorchType(
|
||||
loc, torchType->cast<c10::ListType>()->getElementType()));
|
||||
}
|
||||
case TypeKind::TupleType: {
|
||||
// TODO: Don't lose the element type information.
|
||||
return npcompTupleTypeGet(context);
|
||||
return npcompBasicpyTupleTypeGet(context);
|
||||
}
|
||||
case TypeKind::StringType: {
|
||||
return npcompBytesTypeGet(context);
|
||||
return npcompBasicpyBytesTypeGet(context);
|
||||
}
|
||||
case TypeKind::DeviceObjType: {
|
||||
return npcompDeviceTypeGet(context);
|
||||
return npcompTorchDeviceTypeGet(context);
|
||||
}
|
||||
default: {
|
||||
std::stringstream message;
|
||||
|
@ -216,8 +215,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
|||
// just erase them and let the compiler decide.
|
||||
|
||||
auto sizes = tensor.sizes();
|
||||
return npcompNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
|
||||
elementType);
|
||||
return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
|
||||
elementType);
|
||||
}
|
||||
|
||||
MlirType
|
||||
|
|
|
@@ -0,0 +1,92 @@
//===-- npcomp-c/BasicpyTypes.h - C API for basicpy types ---------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_C_BASICPYTYPES_H
#define NPCOMP_C_BASICPYTYPES_H

#include "mlir-c/IR.h"

#ifdef __cplusplus
extern "C" {
#endif

//===----------------------------------------------------------------------===//
// !basicpy.BoolType
//===----------------------------------------------------------------------===//

/// Checks whether the given type is the Python "bool" type.
bool npcompTypeIsABasicpyBool(MlirType t);

/// Gets the Python "bool" type.
MlirType npcompBasicpyBoolTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// !basicpy.BytesType
//===----------------------------------------------------------------------===//

/// Checks whether the given type is the Python "bytes" type.
bool npcompTypeIsABasicpyBytes(MlirType t);

/// Gets the Python "bytes" type.
MlirType npcompBasicpyBytesTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// !basicpy.DictType
//===----------------------------------------------------------------------===//

/// Checks whether the given type is the Python "dict" type.
bool npcompTypeIsABasicpyDict(MlirType t);

/// Gets the generic Python "dict" type.
MlirType npcompBasicpyDictTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// List type
//===----------------------------------------------------------------------===//

/// Checks whether the given type is the Python "list" type.
bool npcompTypeIsABasicpyList(MlirType t);

/// Gets the generic Python "list" type.
MlirType npcompBasicpyListTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// !basicpy.NoneType type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a `!basicpy.NoneType`.
bool npcompTypeIsABasicpyNone(MlirType t);

/// Gets the `!basicpy.NoneType` type.
MlirType npcompBasicpyNoneTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// SlotObject type.
//===----------------------------------------------------------------------===//

MlirType npcompBasicPySlotObjectTypeGet(MlirContext context,
                                        MlirStringRef className,
                                        intptr_t slotTypeCount,
                                        const MlirType *slotTypes);

//===----------------------------------------------------------------------===//
// !basicpy.TupleType
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a `!basicpy.TupleType`.
bool npcompTypeIsABasicpyTuple(MlirType t);

/// Gets the generic Python "tuple" type.
MlirType npcompBasicpyTupleTypeGet(MlirContext context);

#ifdef __cplusplus
}
#endif

#endif // NPCOMP_C_BASICPYTYPES_H
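A short usage sketch of the basicpy portion of this API (my illustration, not part of the patch); it assumes `t` was produced by a context that has the basicpy dialect loaded:

#include "npcomp-c/BasicpyTypes.h"

// Returns true if `t` is one of the basicpy types the PyTorch importer uses
// to materialize constants (bool / bytes / tuple). Note the predicates now
// return `bool`, matching the MLIR core convention.
static bool isBasicpyConstantType(MlirType t) {
  return npcompTypeIsABasicpyBool(t) || npcompTypeIsABasicpyBytes(t) ||
         npcompTypeIsABasicpyTuple(t);
}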
@@ -0,0 +1,55 @@
//===-- npcomp-c/NumpyTypes.h - C API for numpy types -------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_C_NUMPYTYPES_H
#define NPCOMP_C_NUMPYTYPES_H

#include "mlir-c/IR.h"

#ifdef __cplusplus
extern "C" {
#endif

//===----------------------------------------------------------------------===//
// !numpy.any_dtype type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is the special "any dtype" type that is used
/// to signal an NDArray or tensor of unknown type.
bool npcompTypeIsANumpyAnyDtype(MlirType t);

/// Gets the "any dtype" type.
MlirType npcompAnyDtypeTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// NDArray type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is an NdArray type.
bool npcompTypeIsANumpyNdArray(MlirType t);

/// Gets a numpy.NdArray type that is unranked.
MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType);

/// Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
/// unknown.
MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
                                         MlirType elementType);

/// Helper that gets an equivalent NdArrayType from a ShapedType.
MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType);

/// Helper that converts an NdArrayType to a TensorType.
MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType);

#ifdef __cplusplus
}
#endif

#endif // NPCOMP_C_NUMPYTYPES_H
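A hedged example of building a ranked NdArray type through this header (assumed usage, not from the patch); `mlirF64TypeGet` is the standard MLIR C API builtin-type getter, and the numpy dialect is assumed to be loaded on `context`:

#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/NumpyTypes.h"

// Builds a rank-2 NdArray type with a dynamic leading dimension (-1 per the
// doc comment above) and f64 elements.
static MlirType makeDynamicRowNdArrayType(MlirContext context) {
  const int64_t shape[2] = {-1, 4};
  MlirType f64 = mlirF64TypeGet(context);
  return npcompNumpyNdArrayTypeGetRanked(/*rank=*/2, shape, f64);
}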
@@ -0,0 +1,143 @@
//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef NPCOMP_C_TORCHTYPES_H
#define NPCOMP_C_TORCHTYPES_H

#include "mlir-c/IR.h"

#ifdef __cplusplus
extern "C" {
#endif

//===----------------------------------------------------------------------===//
// torch.nn.Module type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a torch.nn.Module type
bool npcompTypeIsATorchNnModule(MlirType t);

/// Gets the !torch.nn.Module type of the specified class.
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
                                    MlirStringRef className);

//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.optional<T> type
bool npcompTypeIsATorchOptional(MlirType t);

/// Gets the !torch.optional<T> type with subtype T.
MlirType npcompTorchOptionalTypeGet(MlirType containedType);

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.list<T> type
bool npcompTypeIsATorchList(MlirType t);

/// Gets the !torch.list<T> type with contained T.
MlirType npcompTorchListTypeGet(MlirType containedType);

//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.Device type
bool npcompTypeIsATorchDevice(MlirType t);

/// Gets the !torch.Device type.
MlirType npcompTorchDeviceTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.LinearParams type
bool npcompTypeIsATorchLinearParams(MlirType t);

/// Gets the !torch.LinearParams type.
MlirType npcompTorchLinearParamsTypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.qint8 type
bool npcompTypeIsATorchQInt8(MlirType t);

/// Gets the !torch.qint8 type.
MlirType npcompTorchQInt8TypeGet(MlirContext context);

//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.tensor type
bool npcompTypeIsATorchNonValueTensor(MlirType t);

/// Gets a !torch.tensor type.
///
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present.
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
                                          intptr_t numSizes,
                                          const int64_t *optionalSizes,
                                          MlirType optionalDtype);

/// Gets the !torch.tensor type with the least static information.
MlirType
npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);

/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type);

//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.vtensor type
bool npcompTypeIsATorchValueTensor(MlirType t);

/// Gets a !torch.vtensor type.
///
/// - `optionalSizes` is allowed to be null, meaning that no size
/// information is present (and `numSizes` is ignored in that case).
/// - `optionalDtype` is allowed to be null, meaning that no dtype
/// information is present.
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
                                       const int64_t *optionalSizes,
                                       MlirType optionalDtype);

/// Gets the !torch.tensor type with the least static information.
MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);

/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type);

//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//

/// Checks whether the given type is a !torch.none type
bool npcompTypeIsATorchNone(MlirType t);

/// Gets the !torch.none type.
MlirType npcompTorchNoneTypeGet(MlirContext context);

#ifdef __cplusplus
}
#endif

#endif // NPCOMP_C_TORCHTYPES_H
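To show how the tensor-type getters above fit together, here is a small sketch (assumptions: the torch dialect is loaded on `context`, and `mlirF32TypeGet` is the standard MLIR C API builtin-type getter); it is an illustration, not code from the patch:

#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/TorchTypes.h"

// Builds a value-semantic !torch.vtensor<[2,3],f32> type and a !torch.tensor
// type with no known sizes or dtype.
static void exampleTorchTensorTypes(MlirContext context) {
  const int64_t sizes[2] = {2, 3};
  MlirType f32 = mlirF32TypeGet(context);
  MlirType vtensor = npcompTorchValueTensorTypeGet(
      context, /*numSizes=*/2, /*optionalSizes=*/sizes, /*optionalDtype=*/f32);
  MlirType anyTensor =
      npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(context);
  (void)vtensor;
  (void)anyTensor;
}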
@ -1,246 +0,0 @@
|
|||
/*===-- npcomp-c/Types.h - NPComp custom types --------------------*- C -*-===*\
|
||||
|* *|
|
||||
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|
||||
|* Exceptions. *|
|
||||
|* See https://llvm.org/LICENSE.txt for license information. *|
|
||||
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|
||||
|* *|
|
||||
\*===----------------------------------------------------------------------===*/
|
||||
|
||||
#ifndef NPCOMP_C_TYPES_H
|
||||
#define NPCOMP_C_TYPES_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsAAnyDtype(MlirType t);
|
||||
|
||||
/** Gets the "any dtype" type. */
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "bool" type. */
|
||||
int npcompTypeIsABool(MlirType t);
|
||||
|
||||
/** Gets the Python "bool" type. */
|
||||
MlirType npcompBoolTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bytes type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "bytes" type. */
|
||||
int npcompTypeIsABytes(MlirType t);
|
||||
|
||||
/** Gets the Python "bytes" type. */
|
||||
MlirType npcompBytesTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Dict type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "dict" type. */
|
||||
int npcompTypeIsADict(MlirType t);
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompDictTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* List type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
int npcompTypeIsABasicpyList(MlirType t);
|
||||
|
||||
/** Gets the generic Python "list" type. */
|
||||
MlirType npcompBasicpyListTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* NDArray type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is an NdArray type. */
|
||||
int npcompTypeIsANdArray(MlirType t);
|
||||
|
||||
/** Gets a numpy.NdArray type that is unranked. */
|
||||
MlirType npcompNdArrayTypeGetUnranked(MlirType elementType);
|
||||
|
||||
/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
|
||||
* unknown. */
|
||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType);
|
||||
|
||||
/// Helper that gets an equivalent NdArrayType from a ShapedType.
|
||||
MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
|
||||
|
||||
/// Helper that converts an NdArrayType to a TensorType.
|
||||
MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* !basicpy.NoneType type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a `!basicpy.NoneType`. */
|
||||
int npcompTypeIsABasicpyNone(MlirType t);
|
||||
|
||||
/** Gets the `!basicpy.NoneType` type. */
|
||||
MlirType npcompBasicpyNoneTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* SlotObject type. */
|
||||
/*============================================================================*/
|
||||
|
||||
MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
|
||||
intptr_t slotTypeCount,
|
||||
const MlirType *slotTypes);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Tuple type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsATuple(MlirType t);
|
||||
|
||||
/** Gets the generic Python "tuple" type. */
|
||||
MlirType npcompTupleTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.nn.Module type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a torch.nn.Module type */
|
||||
int npcompTypeIsANnModule(MlirType t);
|
||||
|
||||
/** Gets the !torch.nn.Module type of the specified class. */
|
||||
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.optional type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.optional<T> type */
|
||||
int npcompTypeIsAOptional(MlirType t);
|
||||
|
||||
/** Gets the !torch.optional<T> type with subtype T. */
|
||||
MlirType npcompOptionalTypeGet(MlirType containedType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.list type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.list<T> type */
|
||||
int npcompTypeIsAList(MlirType t);
|
||||
|
||||
/** Gets the !torch.list<T> type with contained T. */
|
||||
MlirType npcompListTypeGet(MlirType containedType);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.Device type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.Device type */
|
||||
int npcompTypeIsADevice(MlirType t);
|
||||
|
||||
/** Gets the !torch.Device type. */
|
||||
MlirType npcompDeviceTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.LinearParams type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.LinearParams type */
|
||||
int npcompTypeIsALinearParams(MlirType t);
|
||||
|
||||
/** Gets the !torch.LinearParams type. */
|
||||
MlirType npcompLinearParamsTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.qint8 type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.qint8 type */
|
||||
int npcompTypeIsAQInt8(MlirType t);
|
||||
|
||||
/** Gets the !torch.qint8 type. */
|
||||
MlirType npcompQInt8TypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.tensor type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.tensor type */
|
||||
int npcompTypeIsANonValueTensor(MlirType t);
|
||||
|
||||
/** Gets a !torch.tensor type.
|
||||
*
|
||||
* - `optionalSizes` is allowed to be null, meaning that no size information is
|
||||
* present (and `numSizes` is ignored in that case).
|
||||
* - `optionalDtype` is allowed to be null, meaning that no dtype information is
|
||||
* present.
|
||||
*
|
||||
*/
|
||||
MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype);
|
||||
|
||||
/** Gets the !torch.tensor type with the least static information. */
|
||||
MlirType
|
||||
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
|
||||
/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
|
||||
MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type);
|
||||
|
||||
/*============================================================================*/
|
||||
/* torch.vtensor type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.vtensor type */
|
||||
int npcompTypeIsAValueTensor(MlirType t);
|
||||
|
||||
/** Gets a !torch.vtensor type.
|
||||
*
|
||||
* - `optionalSizes` is allowed to be null, meaning that no size information is
|
||||
* present (and `numSizes` is ignored in that case).
|
||||
* - `optionalDtype` is allowed to be null, meaning that no dtype information is
|
||||
* present.
|
||||
*
|
||||
*/
|
||||
MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype);
|
||||
|
||||
/** Gets the !torch.tensor type with the least static information. */
|
||||
MlirType
|
||||
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
|
||||
|
||||
/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
|
||||
MlirType npcompValueTensorTypeGetFromShaped(MlirType type);
|
||||
|
||||
/*============================================================================*/
|
||||
/* !torch.none type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is a !torch.none type */
|
||||
int npcompTypeIsATorchNone(MlirType t);
|
||||
|
||||
/** Gets the !torch.none type. */
|
||||
MlirType npcompTorchNoneTypeGet(MlirContext context);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // NPCOMP_C_TYPES_H
|
|
@ -0,0 +1,116 @@
|
|||
//===- BasicpyTypes.cpp - C Interface for basicpy types -------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/BasicpyTypes.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bool type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsABasicpyBool(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::BoolType>();
|
||||
}
|
||||
|
||||
MlirType npcompBasicpyBoolTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::BoolType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Bytes type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsABasicpyBytes(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::BytesType>();
|
||||
}
|
||||
|
||||
MlirType npcompBasicpyBytesTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::BytesType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dict type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/** Checks whether the given type is the Python "dict" type. */
|
||||
bool npcompTypeIsABasicpyDict(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::DictType>();
|
||||
}
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompBasicpyDictTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::DictType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// List type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
bool npcompTypeIsABasicpyList(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::ListType>();
|
||||
}
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompBasicpyListTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::ListType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !basicpy.NoneType type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/** Checks whether the given type is a `!basicpy.NoneType`. */
|
||||
bool npcompTypeIsANone(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::NoneType>();
|
||||
}
|
||||
|
||||
/** Gets the `!basicpy.NoneType` type. */
|
||||
MlirType npcompBasicpyNoneTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::NoneType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SlotObject type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
MlirType npcompBasicPySlotObjectTypeGet(MlirContext context,
|
||||
MlirStringRef className,
|
||||
intptr_t slotTypeCount,
|
||||
const MlirType *slotTypes) {
|
||||
MLIRContext *cppContext = unwrap(context);
|
||||
auto classNameAttr = StringAttr::get(cppContext, unwrap(className));
|
||||
SmallVector<Type> slotTypesCpp;
|
||||
slotTypesCpp.resize(slotTypeCount);
|
||||
for (intptr_t i = 0; i < slotTypeCount; ++i) {
|
||||
slotTypesCpp[i] = unwrap(slotTypes[i]);
|
||||
}
|
||||
return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tuple type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
bool npcompTypeIsABasicpyTuple(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::TupleType>();
|
||||
}
|
||||
|
||||
/** Gets the "any dtype" type. */
|
||||
MlirType npcompBasicpyTupleTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::TupleType::get(unwrap(context)));
|
||||
}
|
|
@@ -7,7 +7,9 @@ set(LLVM_LINK_COMPONENTS
 add_npcomp_library(NPCOMPCAPI
   InitLLVM.cpp
   Registration.cpp
-  Types.cpp
+  BasicpyTypes.cpp
+  NumpyTypes.cpp
+  TorchTypes.cpp

   LINK_LIBS PUBLIC
   MLIRExecutionEngine
@ -0,0 +1,56 @@
|
|||
//===- NumpyTypes.cpp - C Interface for numpy types -----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/NumpyTypes.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// !numpy.any_dtype type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsANumpyAnyDtype(MlirType t) {
|
||||
return unwrap(t).isa<Numpy::AnyDtypeType>();
|
||||
}
|
||||
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
|
||||
return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NDArray type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsANumpyNdArray(MlirType t) {
|
||||
return unwrap(t).isa<Numpy::NdArrayType>();
|
||||
}
|
||||
|
||||
MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType) {
|
||||
return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
|
||||
}
|
||||
|
||||
MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType) {
|
||||
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
|
||||
return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
|
||||
}
|
||||
|
||||
MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType) {
|
||||
return wrap(Numpy::NdArrayType::getFromShapedType(
|
||||
unwrap(shapedType).cast<ShapedType>()));
|
||||
}
|
||||
|
||||
MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType) {
|
||||
return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
|
||||
}
|
|
@ -0,0 +1,165 @@
|
|||
//===- TorchTypes.cpp - C Interface for torch types -----------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/TorchTypes.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.nn.Module type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/** Checks whether the given type is a torch.nn.Module type */
|
||||
bool npcompTypeIsATorchNnModule(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NnModuleType>();
|
||||
}
|
||||
|
||||
/** Gets the torch.nn.Module type of the specified class. */
|
||||
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
|
||||
MlirStringRef className) {
|
||||
return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.optional type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/** Checks whether the given type is a !torch.optional<T> type */
|
||||
bool npcompTypeIsATorchOptional(MlirType t) {
|
||||
return unwrap(t).isa<Torch::OptionalType>();
|
||||
}
|
||||
|
||||
/** Gets the !torch.optional<T> type with subtype T. */
|
||||
MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::OptionalType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.list<T> type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchList(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ListType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchListTypeGet(MlirType containedType) {
|
||||
return wrap(Torch::ListType::get(unwrap(containedType)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.Device type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchDevice(MlirType t) {
|
||||
return unwrap(t).isa<Torch::DeviceType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchDeviceTypeGet(MlirContext context) {
|
||||
return wrap(Torch::DeviceType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.LinearParams type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchLinearParams(MlirType t) {
|
||||
return unwrap(t).isa<Torch::LinearParamsType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
|
||||
return wrap(Torch::LinearParamsType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.qint8 type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchQInt8(MlirType t) {
|
||||
return unwrap(t).isa<Torch::QInt8Type>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchQInt8TypeGet(MlirContext context) {
|
||||
return wrap(Torch::QInt8Type::get(unwrap(context)));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.tensor type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNonValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NonValueTensorType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
|
||||
intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::NonValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
||||
MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(
|
||||
MlirContext context) {
|
||||
return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
|
||||
unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
|
||||
return wrap(Torch::NonValueTensorType::getFromShaped(
|
||||
unwrap(type).cast<ShapedType>()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.vtensor type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchValueTensor(MlirType t) {
|
||||
return unwrap(t).isa<Torch::ValueTensorType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
|
||||
const int64_t *optionalSizes,
|
||||
MlirType optionalDtype) {
|
||||
Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
|
||||
if (optionalSizes)
|
||||
optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
|
||||
return wrap(Torch::ValueTensorType::get(
|
||||
unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
|
||||
}
|
||||
|
||||
MlirType
|
||||
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
|
||||
return wrap(
|
||||
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
|
||||
}
|
||||
|
||||
MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
|
||||
return wrap(
|
||||
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// torch.none type.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
bool npcompTypeIsATorchNone(MlirType t) {
|
||||
return unwrap(t).isa<Torch::NoneType>();
|
||||
}
|
||||
|
||||
MlirType npcompTorchNoneTypeGet(MlirContext context) {
|
||||
return wrap(Torch::NoneType::get(unwrap(context)));
|
||||
}
|
|
@ -1,303 +0,0 @@
|
|||
//===- Types.cpp - C Interface for NPComp types ---------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/Types.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Support.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsAAnyDtype(MlirType t) {
|
||||
return unwrap(t).isa<Numpy::AnyDtypeType>();
|
||||
}
|
||||
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
|
||||
return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<Basicpy::BoolType>(); }
|
||||
|
||||
MlirType npcompBoolTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::BoolType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bytes type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsABytes(MlirType t) {
|
||||
return unwrap(t).isa<Basicpy::BytesType>();
|
||||
}
|
||||
|
||||
MlirType npcompBytesTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::BytesType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Dict type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "dict" type. */
|
||||
int npcompTypeIsADict(MlirType t) { return unwrap(t).isa<Basicpy::DictType>(); }
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompDictTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::DictType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* List type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "list" type. */
|
||||
int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }
|
||||
|
||||
/** Gets the generic Python "dict" type. */
|
||||
MlirType npcompBasicpyListTypeGet(MlirContext context) {
|
||||
return wrap(Basicpy::ListType::get(unwrap(context)));
|
||||
}
|
||||
|
/*============================================================================*/
/* NDArray type. */
/*============================================================================*/

int npcompTypeIsANdArray(MlirType t) {
  return unwrap(t).isa<Numpy::NdArrayType>();
}

MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) {
  return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
}

MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
                                    MlirType elementType) {
  llvm::ArrayRef<int64_t> shapeArray(shape, rank);
  return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
}

MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
  return wrap(Numpy::NdArrayType::getFromShapedType(
      unwrap(shapedType).cast<ShapedType>()));
}

MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType) {
  return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
}

/*============================================================================*/
/* !basicpy.NoneType type. */
/*============================================================================*/

/** Checks whether the given type is a `!basicpy.NoneType`. */
int npcompTypeIsANone(MlirType t) { return unwrap(t).isa<Basicpy::NoneType>(); }

/** Gets the `!basicpy.NoneType` type. */
MlirType npcompBasicpyNoneTypeGet(MlirContext context) {
  return wrap(Basicpy::NoneType::get(unwrap(context)));
}

/*============================================================================*/
/* SlotObject type. */
/*============================================================================*/

MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
                                 intptr_t slotTypeCount,
                                 const MlirType *slotTypes) {
  MLIRContext *cppContext = unwrap(context);
  auto classNameAttr = StringAttr::get(cppContext, unwrap(className));
  SmallVector<Type> slotTypesCpp;
  slotTypesCpp.resize(slotTypeCount);
  for (intptr_t i = 0; i < slotTypeCount; ++i) {
    slotTypesCpp[i] = unwrap(slotTypes[i]);
  }
  return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
}

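A minimal usage sketch of the getter above (illustrative only, not part of this change; `ctx` is assumed to be an MlirContext with the npcomp dialects already registered):

// Illustrative sketch only: build a two-slot !basicpy.SlotObject type
// from the C API getters defined in this file.
MlirType slots[2] = {npcompBoolTypeGet(ctx), npcompBytesTypeGet(ctx)};
MlirType slotObject = npcompSlotObjectTypeGet(
    ctx, mlirStringRefCreateFromCString("MyClass"), /*slotTypeCount=*/2, slots);
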
/*============================================================================*/
/* Tuple type. */
/*============================================================================*/

/** Checks whether the given type is the generic Python "tuple" type. */
int npcompTypeIsATuple(MlirType t) {
  return unwrap(t).isa<Basicpy::TupleType>();
}

/** Gets the generic Python "tuple" type. */
MlirType npcompTupleTypeGet(MlirContext context) {
  return wrap(Basicpy::TupleType::get(unwrap(context)));
}

/*============================================================================*/
/* torch.nn.Module type. */
/*============================================================================*/

/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t) {
  return unwrap(t).isa<Torch::NnModuleType>();
}

/** Gets the torch.nn.Module type of the specified class. */
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) {
  return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}

/*============================================================================*/
/* torch.optional type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.optional<T> type */
int npcompTypeIsAOptional(MlirType t) {
  return unwrap(t).isa<Torch::OptionalType>();
}

/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompOptionalTypeGet(MlirType containedType) {
  return wrap(Torch::OptionalType::get(unwrap(containedType)));
}

/*============================================================================*/
/* torch.list type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.list<T> type */
int npcompTypeIsAList(MlirType t) {
  return unwrap(t).isa<Torch::ListType>();
}

/** Gets the !torch.list<T> type with element type T. */
MlirType npcompListTypeGet(MlirType containedType) {
  return wrap(Torch::ListType::get(unwrap(containedType)));
}

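These getters compose directly; a rough sketch (illustrative only, assuming an MlirContext `ctx` with the Torch dialect loaded) of building !torch.list<!torch.optional<!torch.none>>:

// Illustrative sketch only; uses getters defined elsewhere in this file.
MlirType noneType = npcompTorchNoneTypeGet(ctx);
MlirType optionalType = npcompOptionalTypeGet(noneType);
MlirType listType = npcompListTypeGet(optionalType);
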
/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.Device type */
int npcompTypeIsADevice(MlirType t) {
  return unwrap(t).isa<Torch::DeviceType>();
}

/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context) {
  return wrap(Torch::DeviceType::get(unwrap(context)));
}

/*============================================================================*/
/* torch.LinearParams type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.LinearParams type */
int npcompTypeIsALinearParams(MlirType t) {
  return unwrap(t).isa<Torch::LinearParamsType>();
}

/** Gets the !torch.LinearParams type. */
MlirType npcompLinearParamsTypeGet(MlirContext context) {
  return wrap(Torch::LinearParamsType::get(unwrap(context)));
}

/*============================================================================*/
/* torch.qint8 type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.qint8 type */
int npcompTypeIsAQInt8(MlirType t) {
  return unwrap(t).isa<Torch::QInt8Type>();
}

/** Gets the !torch.qint8 type. */
MlirType npcompQInt8TypeGet(MlirContext context) {
  return wrap(Torch::QInt8Type::get(unwrap(context)));
}

/*============================================================================*/
/* torch.tensor type. */
/*============================================================================*/

int npcompTypeIsANonValueTensor(MlirType t) {
  return unwrap(t).isa<Torch::NonValueTensorType>();
}

MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
                                     const int64_t *optionalSizes,
                                     MlirType optionalDtype) {
  Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
  if (optionalSizes)
    optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
  return wrap(Torch::NonValueTensorType::get(
      unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

MlirType
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
  return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
      unwrap(context)));
}

MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type) {
  return wrap(Torch::NonValueTensorType::getFromShaped(
      unwrap(type).cast<ShapedType>()));
}

/*============================================================================*/
/* torch.vtensor type. */
/*============================================================================*/

int npcompTypeIsAValueTensor(MlirType t) {
  return unwrap(t).isa<Torch::ValueTensorType>();
}

MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
                                  const int64_t *optionalSizes,
                                  MlirType optionalDtype) {
  Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
  if (optionalSizes)
    optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
  return wrap(Torch::ValueTensorType::get(
      unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

MlirType
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
  return wrap(
      Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}

MlirType npcompValueTensorTypeGetFromShaped(MlirType type) {
  return wrap(
      Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}

/*============================================================================*/
/* torch.none type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.none type */
int npcompTypeIsATorchNone(MlirType t) {
  return unwrap(t).isa<Torch::NoneType>();
}

/** Gets the !torch.none type. */
MlirType npcompTorchNoneTypeGet(MlirContext context) {
  return wrap(Torch::NoneType::get(unwrap(context)));
}

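A rough usage sketch of the tensor-type getters above (illustrative only, not part of this change; `ctx` is assumed to be an MlirContext with the Torch dialect registered):

// Illustrative sketch only: a ranked, dtyped !torch.tensor type.
int64_t sizes[2] = {2, 3};
MlirType f32 = mlirF32TypeGet(ctx);
MlirType tensorType = npcompNonValueTensorTypeGet(ctx, /*numSizes=*/2, sizes, f32);
assert(npcompTypeIsANonValueTensor(tensorType));
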
@ -12,9 +12,10 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp-c/NumpyTypes.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"
#include "npcomp/Python/PybindUtils.h"

#ifdef NPCOMP_ENABLE_REFJIT

@ -27,21 +28,21 @@ MlirType shapedToNdArrayArrayType(MlirType shaped_type) {
  if (!mlirTypeIsAShaped(shaped_type)) {
    throw py::raiseValueError("type is not a shaped type");
  }
  return npcompNdArrayTypeGetFromShaped(shaped_type);
  return npcompNumpyNdArrayTypeGetFromShaped(shaped_type);
}

MlirType ndarrayToTensorType(MlirType ndarray_type) {
  if (!npcompTypeIsANdArray(ndarray_type)) {
  if (!npcompTypeIsANumpyNdArray(ndarray_type)) {
    throw py::raiseValueError("type is not an ndarray type");
  }
  return npcompNdArrayTypeToTensor(ndarray_type);
  return npcompNumpyNdArrayTypeToTensor(ndarray_type);
}

MlirType slotObjectType(MlirContext context, const std::string &className,
                        const std::vector<MlirType> &slotTypes) {
  MlirStringRef classNameSr{className.data(), className.size()};
  return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(),
                                   slotTypes.data());
  return ::npcompBasicPySlotObjectTypeGet(context, classNameSr,
                                          slotTypes.size(), slotTypes.data());
}

// TODO: Move this upstream.

@ -12,8 +12,9 @@

#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/NumpyTypes.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"

#include <assert.h>
#include <math.h>

@ -24,30 +25,31 @@
// Dumps an instance of all NPComp types.
static int printStandardTypes(MlirContext ctx) {
  // Bool type.
  MlirType boolType = npcompBoolTypeGet(ctx);
  if (!npcompTypeIsABool(boolType))
  MlirType boolType = npcompBasicpyBoolTypeGet(ctx);
  if (!npcompTypeIsABasicpyBool(boolType))
    return 1;
  mlirTypeDump(boolType);
  fprintf(stderr, "\n");

  // Bytes type.
  MlirType bytesType = npcompBytesTypeGet(ctx);
  if (!npcompTypeIsABytes(bytesType))
  MlirType bytesType = npcompBasicpyBytesTypeGet(ctx);
  if (!npcompTypeIsABasicpyBytes(bytesType))
    return 1;
  mlirTypeDump(bytesType);
  fprintf(stderr, "\n");

  // Any dtype.
  MlirType anyDtype = npcompAnyDtypeTypeGet(ctx);
  if (!npcompTypeIsAAnyDtype(anyDtype))
  if (!npcompTypeIsANumpyAnyDtype(anyDtype))
    return 2;
  mlirTypeDump(anyDtype);
  fprintf(stderr, "\n");

  // Ranked NdArray.
  int64_t fourDim = 4;
  MlirType rankedNdArray = npcompNdArrayTypeGetRanked(1, &fourDim, boolType);
  if (!npcompTypeIsANdArray(rankedNdArray))
  MlirType rankedNdArray =
      npcompNumpyNdArrayTypeGetRanked(1, &fourDim, boolType);
  if (!npcompTypeIsANumpyNdArray(rankedNdArray))
    return 3;
  mlirTypeDump(rankedNdArray);
  fprintf(stderr, "\n");