diff --git a/frontends/pytorch/csrc/builder/acap_dispatch.cpp b/frontends/pytorch/csrc/builder/acap_dispatch.cpp index 656b79e75..32a45779d 100644 --- a/frontends/pytorch/csrc/builder/acap_dispatch.cpp +++ b/frontends/pytorch/csrc/builder/acap_dispatch.cpp @@ -11,7 +11,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "npcomp-c/Types.h" +#include "npcomp-c/TorchTypes.h" #include "npcomp/Python/PybindUtils.h" #include @@ -512,9 +512,8 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc, return mlirIntegerTypeGet(funcBuilder->getContext(), 1); } if (ival.isList()) { - return npcompListTypeGet( - typeMapper.mapFromTorchType( - loc, ival.toList().elementType())); + return npcompTorchListTypeGet( + typeMapper.mapFromTorchType(loc, ival.toList().elementType())); } if (ival.isNone()) { return npcompTorchNoneTypeGet(funcBuilder->getContext()); @@ -530,7 +529,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) { MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); MlirOperation tensorOp = createMlirOperationAtEnd( funcBuilder->getEntryBlock(), "torch.tensor", loc, - npcompNonValueTensorTypeGetFromShaped( + npcompTorchNonValueTensorTypeGetFromShaped( mlirAttributeGetType(denseElements)), toMlirNamedAttribute("value", denseElements)); MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0); diff --git a/frontends/pytorch/csrc/builder/func_builder.cpp b/frontends/pytorch/csrc/builder/func_builder.cpp index 3d175f9ec..2c896fc0c 100644 --- a/frontends/pytorch/csrc/builder/func_builder.cpp +++ b/frontends/pytorch/csrc/builder/func_builder.cpp @@ -12,7 +12,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -#include "npcomp-c/Types.h" +#include "npcomp-c/TorchTypes.h" using namespace torch_mlir; @@ -132,7 +132,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc, MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType, std::vector &elements) { - MlirType resultType = npcompListTypeGet(elementType); + MlirType resultType = npcompTorchListTypeGet(elementType); OperationStateHolder state{"torch.prim.ListConstruct", loc}; mlirOperationStateAddResults(state, 1, &resultType); mlirOperationStateAddOperands(state, elements.size(), elements.data()); diff --git a/frontends/pytorch/csrc/builder/ivalue_importer.cpp b/frontends/pytorch/csrc/builder/ivalue_importer.cpp index cfa345438..a4cbf2ce7 100644 --- a/frontends/pytorch/csrc/builder/ivalue_importer.cpp +++ b/frontends/pytorch/csrc/builder/ivalue_importer.cpp @@ -16,7 +16,8 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -#include "npcomp-c/Types.h" +#include "npcomp-c/BasicpyTypes.h" +#include "npcomp-c/TorchTypes.h" #include "caffe2/core/scope_guard.h" #include "ATen/native/quantized/cpu/packed_params.h" @@ -170,7 +171,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) { MlirOperation nnModule = createMlirOperation( "torch.nn_module", loc, - npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), + npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)), mlirRegionCreate()); MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0); mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr)); @@ -240,7 +241,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { MlirLocation loc = mlirLocationUnknownGet(context); if (ivalue.isBool()) { - MlirType type = npcompBoolTypeGet(context); + MlirType type = npcompBasicpyBoolTypeGet(context); MlirOperation operation = createMlirOperationAtEnd( importBlock, "basicpy.bool_constant", loc, type, toMlirNamedAttribute("value", @@ -269,11 +270,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { for (const c10::IValue &elem : list) { elems.push_back(importIValue(elem)); } - MlirOperation operation = - createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc, - npcompListTypeGet( - typeMapper.mapFromTorchType( - loc, list.elementType())), elems); + MlirOperation operation = createMlirOperationAtEnd( + importBlock, "torch.prim.ListConstruct", loc, + npcompTorchListTypeGet( + typeMapper.mapFromTorchType(loc, list.elementType())), + elems); return mlirOperationGetResult(operation, 0); } if (ivalue.isTuple()) { @@ -284,7 +285,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { } MlirOperation operation = createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc, - npcompTupleTypeGet(context), elems); + npcompBasicpyTupleTypeGet(context), elems); return mlirOperationGetResult(operation, 0); } if (ivalue.isTensor()) { @@ -294,7 +295,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { return importModule(ivalue.toModule()); } if (ivalue.isString()) { - MlirType type = npcompBytesTypeGet(context); + MlirType type = npcompBasicpyBytesTypeGet(context); MlirOperation operation = createMlirOperationAtEnd( importBlock, "basicpy.bytes_constant", loc, type, toMlirNamedAttribute( @@ -325,7 +326,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) { } MlirOperation operation = createMlirOperationAtEnd( importBlock, "torch.linear_params.create", loc, - npcompLinearParamsTypeGet(context), weightValue, biasValue); + npcompTorchLinearParamsTypeGet(context), weightValue, biasValue); return mlirOperationGetResult(operation, 0); } } @@ -345,7 +346,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) { MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc); MlirOperation tensorOp = createMlirOperationAtEnd(importBlock, "torch.tensor", loc, - npcompNonValueTensorTypeGetFromShaped( + npcompTorchNonValueTensorTypeGetFromShaped( mlirAttributeGetType(denseElements)), toMlirNamedAttribute("value", denseElements)); MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0); @@ -360,7 +361,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) { // compiler stages that are building a statically modeled quantization // representation will need to convert this to their representation. std::vector shape(tensor.sizes().begin(), tensor.sizes().end()); - MlirType quantizedTensorType = npcompNonValueTensorTypeGet( + MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet( context, shape.size(), shape.data(), typeMapper.mapFromTorchScalarType(tensor.scalar_type())); if (tensor.qscheme() == c10::kPerTensorAffine) { @@ -504,11 +505,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) { mlirLocationUnknownGet(context), *maybeDtype); MlirType typeBound; if (hasValueSemantics) { - typeBound = npcompValueTensorTypeGet(context, shape.size(), - shape.data(), dtype); + typeBound = npcompTorchValueTensorTypeGet(context, shape.size(), + shape.data(), dtype); } else { - typeBound = npcompNonValueTensorTypeGet(context, shape.size(), - shape.data(), dtype); + typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(), + shape.data(), dtype); } MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute( diff --git a/frontends/pytorch/csrc/builder/node_importer.cpp b/frontends/pytorch/csrc/builder/node_importer.cpp index 5ae631f30..6fe4b0872 100644 --- a/frontends/pytorch/csrc/builder/node_importer.cpp +++ b/frontends/pytorch/csrc/builder/node_importer.cpp @@ -15,7 +15,8 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -#include "npcomp-c/Types.h" +#include "npcomp-c/BasicpyTypes.h" +#include "npcomp-c/TorchTypes.h" namespace py = pybind11; using namespace torch_mlir; @@ -145,7 +146,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) { loc, appendToBlock), mlirRegionCreate()); mapResults(node, operation); - std::vector terminatorOperandTypes = {npcompBoolTypeGet(context)}; + std::vector terminatorOperandTypes = { + npcompBasicpyBoolTypeGet(context)}; terminatorOperandTypes.insert(terminatorOperandTypes.end(), resultTypes.begin(), resultTypes.end()); auto createTerminator = [&](c10::ArrayRef yieldedValues, diff --git a/frontends/pytorch/csrc/builder/op_builder.cpp b/frontends/pytorch/csrc/builder/op_builder.cpp index b712477e8..2d67e7dac 100644 --- a/frontends/pytorch/csrc/builder/op_builder.cpp +++ b/frontends/pytorch/csrc/builder/op_builder.cpp @@ -10,7 +10,8 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -#include "npcomp-c/Types.h" +#include "npcomp-c/BasicpyTypes.h" +#include "npcomp-c/TorchTypes.h" using namespace torch_mlir; @@ -23,14 +24,14 @@ MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) { MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) { return createMlirOperation( - "basicpy.bool_constant", loc, npcompBoolTypeGet(context), + "basicpy.bool_constant", loc, npcompBasicpyBoolTypeGet(context), toMlirNamedAttribute("value", mlirBoolAttrGet(context, value))); } MlirOperation OpBuilder::createBytesConstant(MlirLocation loc, const std::string &value) { return createMlirOperation( - "basicpy.bytes_constant", loc, npcompBytesTypeGet(context), + "basicpy.bytes_constant", loc, npcompBasicpyBytesTypeGet(context), toMlirNamedAttribute("value", mlirStringAttrGet(context, toMlirStringRef(value)))); } diff --git a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp index b6e34116d..ad493233b 100644 --- a/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp +++ b/frontends/pytorch/csrc/builder/torch_to_mlir_utils.cpp @@ -15,7 +15,8 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" -#include "npcomp-c/Types.h" +#include "npcomp-c/BasicpyTypes.h" +#include "npcomp-c/TorchTypes.h" using namespace torch_mlir; @@ -54,7 +55,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) { case ScalarType::Long: return mlirIntegerTypeSignedGet(context, 64); case ScalarType::Bool: - return npcompBoolTypeGet(context); + return npcompBasicpyBoolTypeGet(context); case ScalarType::Double: return mlirF64TypeGet(context); case ScalarType::Float: @@ -64,7 +65,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) { case ScalarType::Half: return mlirF16TypeGet(context); case ScalarType::QInt8: - return npcompQInt8TypeGet(context); + return npcompTorchQInt8TypeGet(context); default: { return {nullptr}; } @@ -103,7 +104,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc, // Individually handle the custom classes that we know about. if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") { - return npcompLinearParamsTypeGet(context); + return npcompTorchLinearParamsTypeGet(context); } // At this point, we know that the type is indeed a custom class type, but @@ -134,11 +135,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, auto &sizes = tensorType->symbolic_sizes(); if (!sizes.rank()) { // Unranked. - return npcompNonValueTensorTypeGet(context, - /*numSizes=*/0, - /*optionalSizes=*/nullptr, - /*optionalDtype=*/ - elementType); + return npcompTorchNonValueTensorTypeGet(context, + /*numSizes=*/0, + /*optionalSizes=*/nullptr, + /*optionalDtype=*/ + elementType); } // Ranked with possibly dynamic dims. auto &symbolicShape = tensorType->symbolic_sizes(); @@ -148,10 +149,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, auto shapeSymbol = symbolicShape[i]; dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1; } - return npcompNonValueTensorTypeGet(context, dims.size(), - /*optionalSizes=*/dims.data(), - /*optionalDtype=*/ - elementType); + return npcompTorchNonValueTensorTypeGet(context, dims.size(), + /*optionalSizes=*/dims.data(), + /*optionalDtype=*/ + elementType); } case TypeKind::ClassType: { const c10::ClassTypePtr &classType = torchType->cast(); @@ -161,15 +162,14 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, } auto maybeName = classType->name(); std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class"; - return npcompNnModuleTypeGet(context, toMlirStringRef(name)); + return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name)); } case TypeKind::FloatType: { return mlirF64TypeGet(context); } case TypeKind::OptionalType: { - return npcompOptionalTypeGet( - mapFromTorchType( - loc, torchType->cast()->getElementType())); + return npcompTorchOptionalTypeGet(mapFromTorchType( + loc, torchType->cast()->getElementType())); } case TypeKind::IntType: { return mlirIntegerTypeGet(context, 64); @@ -178,22 +178,21 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc, return npcompTorchNoneTypeGet(context); } case TypeKind::BoolType: { - return npcompBoolTypeGet(context); + return npcompBasicpyBoolTypeGet(context); } case TypeKind::ListType: { - return npcompListTypeGet( - mapFromTorchType( - loc, torchType->cast()->getElementType())); + return npcompTorchListTypeGet(mapFromTorchType( + loc, torchType->cast()->getElementType())); } case TypeKind::TupleType: { // TODO: Don't lose the element type information. - return npcompTupleTypeGet(context); + return npcompBasicpyTupleTypeGet(context); } case TypeKind::StringType: { - return npcompBytesTypeGet(context); + return npcompBasicpyBytesTypeGet(context); } case TypeKind::DeviceObjType: { - return npcompDeviceTypeGet(context); + return npcompTorchDeviceTypeGet(context); } default: { std::stringstream message; @@ -216,8 +215,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) { // just erase them and let the compiler decide. auto sizes = tensor.sizes(); - return npcompNonValueTensorTypeGet(context, sizes.size(), sizes.data(), - elementType); + return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(), + elementType); } MlirType diff --git a/include/npcomp-c/BasicpyTypes.h b/include/npcomp-c/BasicpyTypes.h new file mode 100644 index 000000000..065ad368a --- /dev/null +++ b/include/npcomp-c/BasicpyTypes.h @@ -0,0 +1,92 @@ +//===-- npcomp-c/BasicpyTypes.h - C API for basicpy types ---------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_C_BASICPYTYPES_H +#define NPCOMP_C_BASICPYTYPES_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// !basicpy.BoolType +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is the Python "bool" type. +bool npcompTypeIsABasicpyBool(MlirType t); + +/// Gets the Python "bool" type. +MlirType npcompBasicpyBoolTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !basicpy.BytesType +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is the Python "bytes" type. +bool npcompTypeIsABasicpyBytes(MlirType t); + +/// Gets the Python "bytes" type. +MlirType npcompBasicpyBytesTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !basicpy.DictType +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is the Python "dict" type. +bool npcompTypeIsABasicpyDict(MlirType t); + +/// Gets the generic Python "dict" type. +MlirType npcompBasicpyDictTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// List type +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is the Python "list" type. +bool npcompTypeIsABasicpyList(MlirType t); + +/// Gets the generic Python "list" type. +MlirType npcompBasicpyListTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// !basicpy.NoneType type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a `!basicpy.NoneType`. +bool npcompTypeIsABasicpyNone(MlirType t); + +/// Gets the `!basicpy.NoneType` type. +MlirType npcompBasicpyNoneTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// SlotObject type. +//===----------------------------------------------------------------------===// + +MlirType npcompBasicPySlotObjectTypeGet(MlirContext context, + MlirStringRef className, + intptr_t slotTypeCount, + const MlirType *slotTypes); + +//===----------------------------------------------------------------------===// +// !basicpy.TupleType +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a `!basicpy.TupleType`. +bool npcompTypeIsABasicpyTuple(MlirType t); + +/// Gets the generic Python "tuple" type. +MlirType npcompBasicpyTupleTypeGet(MlirContext context); + +#ifdef __cplusplus +} +#endif + +#endif // NPCOMP_C_BASICPYTYPES_H diff --git a/include/npcomp-c/NumpyTypes.h b/include/npcomp-c/NumpyTypes.h new file mode 100644 index 000000000..56867f541 --- /dev/null +++ b/include/npcomp-c/NumpyTypes.h @@ -0,0 +1,55 @@ +//===-- npcomp-c/NumpyTypes.h - C API for numpy types -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_C_NUMPYTYPES_H +#define NPCOMP_C_NUMPYTYPES_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// !numpy.any_dtype type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is the special "any dtype" type that is used +// to signal an NDArray or tensor of unknown type. +bool npcompTypeIsANumpyAnyDtype(MlirType t); + +/// Gets the "any dtype" type. +MlirType npcompAnyDtypeTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// NDArray type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is an NdArray type. +bool npcompTypeIsANumpyNdArray(MlirType t); + +/// Gets a numpy.NdArray type that is unranked. +MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType); + +/// Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are +/// unknown. +MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, + MlirType elementType); + +/// Helper that gets an equivalent NdArrayType from a ShapedType. +MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType); + +/// Helper that converts an NdArrayType to a TensorType. +MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType); + +#ifdef __cplusplus +} +#endif + +#endif // NPCOMP_C_NUMPYTYPES_H diff --git a/include/npcomp-c/TorchTypes.h b/include/npcomp-c/TorchTypes.h new file mode 100644 index 000000000..06be691a0 --- /dev/null +++ b/include/npcomp-c/TorchTypes.h @@ -0,0 +1,143 @@ +//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef NPCOMP_C_TORCHTYPES_H +#define NPCOMP_C_TORCHTYPES_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// torch.nn.Module type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a torch.nn.Module type +bool npcompTypeIsATorchNnModule(MlirType t); + +/// Gets the !torch.nn.Module type of the specified class. +MlirType npcompTorchNnModuleTypeGet(MlirContext context, + MlirStringRef className); + +//===----------------------------------------------------------------------===// +// torch.optional type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.optional type +bool npcompTypeIsATorchOptional(MlirType t); + +/// Gets the !torch.optional type with subtype T. +MlirType npcompTorchOptionalTypeGet(MlirType containedType); + +//===----------------------------------------------------------------------===// +// torch.list type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.list type +bool npcompTypeIsATorchList(MlirType t); + +/// Gets the !torch.list type with contained T. +MlirType npcompTorchListTypeGet(MlirType containedType); + +//===----------------------------------------------------------------------===// +// torch.Device type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.Device type +bool npcompTypeIsATorchDevice(MlirType t); + +/// Gets the !torch.Device type. +MlirType npcompTorchDeviceTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.LinearParams type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.LinearParams type +bool npcompTypeIsATorchLinearParams(MlirType t); + +/// Gets the !torch.LinearParams type. +MlirType npcompTorchLinearParamsTypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.qint8 type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.qint8 type +bool npcompTypeIsATorchQInt8(MlirType t); + +/// Gets the !torch.qint8 type. +MlirType npcompTorchQInt8TypeGet(MlirContext context); + +//===----------------------------------------------------------------------===// +// torch.tensor type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.tensor type +bool npcompTypeIsATorchNonValueTensor(MlirType t); + +/// Gets a !torch.tensor type. +/// +/// - `optionalSizes` is allowed to be null, meaning that no size +/// information is present (and `numSizes` is ignored in that case). - +/// `optionalDtype` is allowed to be null, meaning that no dtype +/// information is present. +MlirType npcompTorchNonValueTensorTypeGet(MlirContext context, + intptr_t numSizes, + const int64_t *optionalSizes, + MlirType optionalDtype); + +/// Gets the !torch.tensor type with the least static information. +MlirType +npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context); + +/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. +MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type); + +//===----------------------------------------------------------------------===// +// torch.vtensor type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.vtensor type +bool npcompTypeIsATorchValueTensor(MlirType t); + +/// Gets a !torch.vtensor type. +/// +/// - `optionalSizes` is allowed to be null, meaning that no size +/// information is present (and `numSizes` is ignored in that case). +/// - `optionalDtype` is allowed to be null, meaning that no dtype +/// information is present. +MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes, + const int64_t *optionalSizes, + MlirType optionalDtype); + +/// Gets the !torch.tensor type with the least static information. +MlirType +npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context); + +/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. +MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type); + +//===----------------------------------------------------------------------===// +// !torch.none type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is a !torch.none type +bool npcompTypeIsATorchNone(MlirType t); + +/// Gets the !torch.none type. +MlirType npcompTorchNoneTypeGet(MlirContext context); + +#ifdef __cplusplus +} +#endif + +#endif // NPCOMP_C_TORCHTYPES_H diff --git a/include/npcomp-c/Types.h b/include/npcomp-c/Types.h deleted file mode 100644 index 541589f9e..000000000 --- a/include/npcomp-c/Types.h +++ /dev/null @@ -1,246 +0,0 @@ -/*===-- npcomp-c/Types.h - NPComp custom types --------------------*- C -*-===*\ -|* *| -|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *| -|* Exceptions. *| -|* See https://llvm.org/LICENSE.txt for license information. *| -|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *| -|* *| -\*===----------------------------------------------------------------------===*/ - -#ifndef NPCOMP_C_TYPES_H -#define NPCOMP_C_TYPES_H - -#include "mlir-c/IR.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/*============================================================================*/ -/* Any dtype type. */ -/*============================================================================*/ - -/** Checks whether the given type is the special "any dtype" type that is used - * to signal an NDArray or tensor of unknown type. */ -int npcompTypeIsAAnyDtype(MlirType t); - -/** Gets the "any dtype" type. */ -MlirType npcompAnyDtypeTypeGet(MlirContext context); - -/*============================================================================*/ -/* Bool type. */ -/*============================================================================*/ - -/** Checks whether the given type is the Python "bool" type. */ -int npcompTypeIsABool(MlirType t); - -/** Gets the Python "bool" type. */ -MlirType npcompBoolTypeGet(MlirContext context); - -/*============================================================================*/ -/* Bytes type. */ -/*============================================================================*/ - -/** Checks whether the given type is the Python "bytes" type. */ -int npcompTypeIsABytes(MlirType t); - -/** Gets the Python "bytes" type. */ -MlirType npcompBytesTypeGet(MlirContext context); - -/*============================================================================*/ -/* Dict type. */ -/*============================================================================*/ - -/** Checks whether the given type is the Python "dict" type. */ -int npcompTypeIsADict(MlirType t); - -/** Gets the generic Python "dict" type. */ -MlirType npcompDictTypeGet(MlirContext context); - -/*============================================================================*/ -/* List type. */ -/*============================================================================*/ - -/** Checks whether the given type is the Python "list" type. */ -int npcompTypeIsABasicpyList(MlirType t); - -/** Gets the generic Python "list" type. */ -MlirType npcompBasicpyListTypeGet(MlirContext context); - -/*============================================================================*/ -/* NDArray type. */ -/*============================================================================*/ - -/** Checks whether the given type is an NdArray type. */ -int npcompTypeIsANdArray(MlirType t); - -/** Gets a numpy.NdArray type that is unranked. */ -MlirType npcompNdArrayTypeGetUnranked(MlirType elementType); - -/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are - * unknown. */ -MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, - MlirType elementType); - -/// Helper that gets an equivalent NdArrayType from a ShapedType. -MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType); - -/// Helper that converts an NdArrayType to a TensorType. -MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType); - -/*============================================================================*/ -/* !basicpy.NoneType type. */ -/*============================================================================*/ - -/** Checks whether the given type is a `!basicpy.NoneType`. */ -int npcompTypeIsABasicpyNone(MlirType t); - -/** Gets the `!basicpy.NoneType` type. */ -MlirType npcompBasicpyNoneTypeGet(MlirContext context); - -/*============================================================================*/ -/* SlotObject type. */ -/*============================================================================*/ - -MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className, - intptr_t slotTypeCount, - const MlirType *slotTypes); - -/*============================================================================*/ -/* Tuple type. */ -/*============================================================================*/ - -/** Checks whether the given type is the special "any dtype" type that is used - * to signal an NDArray or tensor of unknown type. */ -int npcompTypeIsATuple(MlirType t); - -/** Gets the generic Python "tuple" type. */ -MlirType npcompTupleTypeGet(MlirContext context); - -/*============================================================================*/ -/* torch.nn.Module type. */ -/*============================================================================*/ - -/** Checks whether the given type is a torch.nn.Module type */ -int npcompTypeIsANnModule(MlirType t); - -/** Gets the !torch.nn.Module type of the specified class. */ -MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className); - -/*============================================================================*/ -/* torch.optional type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.optional type */ -int npcompTypeIsAOptional(MlirType t); - -/** Gets the !torch.optional type with subtype T. */ -MlirType npcompOptionalTypeGet(MlirType containedType); - -/*============================================================================*/ -/* torch.list type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.list type */ -int npcompTypeIsAList(MlirType t); - -/** Gets the !torch.list type with contained T. */ -MlirType npcompListTypeGet(MlirType containedType); - -/*============================================================================*/ -/* torch.Device type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.Device type */ -int npcompTypeIsADevice(MlirType t); - -/** Gets the !torch.Device type. */ -MlirType npcompDeviceTypeGet(MlirContext context); - -/*============================================================================*/ -/* torch.LinearParams type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.LinearParams type */ -int npcompTypeIsALinearParams(MlirType t); - -/** Gets the !torch.LinearParams type. */ -MlirType npcompLinearParamsTypeGet(MlirContext context); - -/*============================================================================*/ -/* torch.qint8 type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.qint8 type */ -int npcompTypeIsAQInt8(MlirType t); - -/** Gets the !torch.qint8 type. */ -MlirType npcompQInt8TypeGet(MlirContext context); - -/*============================================================================*/ -/* torch.tensor type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.tensor type */ -int npcompTypeIsANonValueTensor(MlirType t); - -/** Gets a !torch.tensor type. - * - * - `optionalSizes` is allowed to be null, meaning that no size information is - * present (and `numSizes` is ignored in that case). - * - `optionalDtype` is allowed to be null, meaning that no dtype information is - * present. - * - */ -MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes, - const int64_t *optionalSizes, - MlirType optionalDtype); - -/** Gets the !torch.tensor type with the least static information. */ -MlirType -npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context); - -/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */ -MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type); - -/*============================================================================*/ -/* torch.vtensor type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.vtensor type */ -int npcompTypeIsAValueTensor(MlirType t); - -/** Gets a !torch.vtensor type. - * - * - `optionalSizes` is allowed to be null, meaning that no size information is - * present (and `numSizes` is ignored in that case). - * - `optionalDtype` is allowed to be null, meaning that no dtype information is - * present. - * - */ -MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes, - const int64_t *optionalSizes, - MlirType optionalDtype); - -/** Gets the !torch.tensor type with the least static information. */ -MlirType -npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context); - -/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */ -MlirType npcompValueTensorTypeGetFromShaped(MlirType type); - -/*============================================================================*/ -/* !torch.none type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.none type */ -int npcompTypeIsATorchNone(MlirType t); - -/** Gets the !torch.none type. */ -MlirType npcompTorchNoneTypeGet(MlirContext context); - -#ifdef __cplusplus -} -#endif - -#endif // NPCOMP_C_TYPES_H diff --git a/lib/CAPI/BasicpyTypes.cpp b/lib/CAPI/BasicpyTypes.cpp new file mode 100644 index 000000000..4548d3412 --- /dev/null +++ b/lib/CAPI/BasicpyTypes.cpp @@ -0,0 +1,116 @@ +//===- BasicpyTypes.cpp - C Interface for basicpy types -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp-c/BasicpyTypes.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/BuiltinTypes.h" +#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +//===----------------------------------------------------------------------===// +// Bool type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsABasicpyBool(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompBasicpyBoolTypeGet(MlirContext context) { + return wrap(Basicpy::BoolType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// Bytes type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsABasicpyBytes(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompBasicpyBytesTypeGet(MlirContext context) { + return wrap(Basicpy::BytesType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// Dict type. +//===----------------------------------------------------------------------===// + +/** Checks whether the given type is the Python "dict" type. */ +bool npcompTypeIsABasicpyDict(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the generic Python "dict" type. */ +MlirType npcompBasicpyDictTypeGet(MlirContext context) { + return wrap(Basicpy::DictType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// List type. +//===----------------------------------------------------------------------===// + +/** Checks whether the given type is the Python "list" type. */ +bool npcompTypeIsABasicpyList(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the generic Python "dict" type. */ +MlirType npcompBasicpyListTypeGet(MlirContext context) { + return wrap(Basicpy::ListType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// !basicpy.NoneType type. +//===----------------------------------------------------------------------===// + +/** Checks whether the given type is a `!basicpy.NoneType`. */ +bool npcompTypeIsANone(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the `!basicpy.NoneType` type. */ +MlirType npcompBasicpyNoneTypeGet(MlirContext context) { + return wrap(Basicpy::NoneType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// SlotObject type. +//===----------------------------------------------------------------------===// + +MlirType npcompBasicPySlotObjectTypeGet(MlirContext context, + MlirStringRef className, + intptr_t slotTypeCount, + const MlirType *slotTypes) { + MLIRContext *cppContext = unwrap(context); + auto classNameAttr = StringAttr::get(cppContext, unwrap(className)); + SmallVector slotTypesCpp; + slotTypesCpp.resize(slotTypeCount); + for (intptr_t i = 0; i < slotTypeCount; ++i) { + slotTypesCpp[i] = unwrap(slotTypes[i]); + } + return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp)); +} + +//===----------------------------------------------------------------------===// +// Tuple type. +//===----------------------------------------------------------------------===// + +/** Checks whether the given type is the special "any dtype" type that is used + * to signal an NDArray or tensor of unknown type. */ +bool npcompTypeIsABasicpyTuple(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the "any dtype" type. */ +MlirType npcompBasicpyTupleTypeGet(MlirContext context) { + return wrap(Basicpy::TupleType::get(unwrap(context))); +} diff --git a/lib/CAPI/CMakeLists.txt b/lib/CAPI/CMakeLists.txt index 351e1090d..15300412e 100644 --- a/lib/CAPI/CMakeLists.txt +++ b/lib/CAPI/CMakeLists.txt @@ -7,7 +7,9 @@ set(LLVM_LINK_COMPONENTS add_npcomp_library(NPCOMPCAPI InitLLVM.cpp Registration.cpp - Types.cpp + BasicpyTypes.cpp + NumpyTypes.cpp + TorchTypes.cpp LINK_LIBS PUBLIC MLIRExecutionEngine diff --git a/lib/CAPI/NumpyTypes.cpp b/lib/CAPI/NumpyTypes.cpp new file mode 100644 index 000000000..a5b811dca --- /dev/null +++ b/lib/CAPI/NumpyTypes.cpp @@ -0,0 +1,56 @@ +//===- NumpyTypes.cpp - C Interface for numpy types -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp-c/NumpyTypes.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/BuiltinTypes.h" +#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +//===----------------------------------------------------------------------===// +// !numpy.any_dtype type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsANumpyAnyDtype(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompAnyDtypeTypeGet(MlirContext context) { + return wrap(Numpy::AnyDtypeType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// NDArray type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsANumpyNdArray(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType) { + return wrap(Numpy::NdArrayType::get(unwrap(elementType))); +} + +MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, + MlirType elementType) { + llvm::ArrayRef shapeArray(shape, rank); + return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray)); +} + +MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType) { + return wrap(Numpy::NdArrayType::getFromShapedType( + unwrap(shapedType).cast())); +} + +MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType) { + return wrap(unwrap(ndarrayType).cast().toTensorType()); +} diff --git a/lib/CAPI/TorchTypes.cpp b/lib/CAPI/TorchTypes.cpp new file mode 100644 index 000000000..f689b04d2 --- /dev/null +++ b/lib/CAPI/TorchTypes.cpp @@ -0,0 +1,165 @@ +//===- TorchTypes.cpp - C Interface for torch types -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "npcomp-c/TorchTypes.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/BuiltinTypes.h" +#include "npcomp/Dialect/Torch/IR/TorchTypes.h" + +using namespace mlir; +using namespace mlir::NPCOMP; + +//===----------------------------------------------------------------------===// +// torch.nn.Module type. +//===----------------------------------------------------------------------===// + +/** Checks whether the given type is a torch.nn.Module type */ +bool npcompTypeIsATorchNnModule(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the torch.nn.Module type of the specified class. */ +MlirType npcompTorchNnModuleTypeGet(MlirContext context, + MlirStringRef className) { + return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); +} + +//===----------------------------------------------------------------------===// +// torch.optional type. +//===----------------------------------------------------------------------===// + +/** Checks whether the given type is a !torch.optional type */ +bool npcompTypeIsATorchOptional(MlirType t) { + return unwrap(t).isa(); +} + +/** Gets the !torch.optional type with subtype T. */ +MlirType npcompTorchOptionalTypeGet(MlirType containedType) { + return wrap(Torch::OptionalType::get(unwrap(containedType))); +} + +//===----------------------------------------------------------------------===// +// torch.list type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchList(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchListTypeGet(MlirType containedType) { + return wrap(Torch::ListType::get(unwrap(containedType))); +} + +//===----------------------------------------------------------------------===// +// torch.Device type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchDevice(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchDeviceTypeGet(MlirContext context) { + return wrap(Torch::DeviceType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// torch.LinearParams type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchLinearParams(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchLinearParamsTypeGet(MlirContext context) { + return wrap(Torch::LinearParamsType::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// torch.qint8 type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchQInt8(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchQInt8TypeGet(MlirContext context) { + return wrap(Torch::QInt8Type::get(unwrap(context))); +} + +//===----------------------------------------------------------------------===// +// torch.tensor type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchNonValueTensor(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchNonValueTensorTypeGet(MlirContext context, + intptr_t numSizes, + const int64_t *optionalSizes, + MlirType optionalDtype) { + Optional> optionalSizesArrayRef = None; + if (optionalSizes) + optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); + return wrap(Torch::NonValueTensorType::get( + unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); +} + +MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation( + MlirContext context) { + return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation( + unwrap(context))); +} + +MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) { + return wrap(Torch::NonValueTensorType::getFromShaped( + unwrap(type).cast())); +} + +//===----------------------------------------------------------------------===// +// torch.vtensor type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchValueTensor(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes, + const int64_t *optionalSizes, + MlirType optionalDtype) { + Optional> optionalSizesArrayRef = None; + if (optionalSizes) + optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); + return wrap(Torch::ValueTensorType::get( + unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); +} + +MlirType +npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) { + return wrap( + Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context))); +} + +MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) { + return wrap( + Torch::ValueTensorType::getFromShaped(unwrap(type).cast())); +} + +//===----------------------------------------------------------------------===// +// torch.none type. +//===----------------------------------------------------------------------===// + +bool npcompTypeIsATorchNone(MlirType t) { + return unwrap(t).isa(); +} + +MlirType npcompTorchNoneTypeGet(MlirContext context) { + return wrap(Torch::NoneType::get(unwrap(context))); +} diff --git a/lib/CAPI/Types.cpp b/lib/CAPI/Types.cpp deleted file mode 100644 index 5e7ff1a64..000000000 --- a/lib/CAPI/Types.cpp +++ /dev/null @@ -1,303 +0,0 @@ -//===- Types.cpp - C Interface for NPComp types ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "npcomp-c/Types.h" - -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" -#include "mlir/IR/BuiltinTypes.h" -#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h" -#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h" -#include "npcomp/Dialect/Torch/IR/TorchTypes.h" - -using namespace mlir; -using namespace mlir::NPCOMP; - -/*============================================================================*/ -/* Any dtype type. */ -/*============================================================================*/ - -int npcompTypeIsAAnyDtype(MlirType t) { - return unwrap(t).isa(); -} - -MlirType npcompAnyDtypeTypeGet(MlirContext context) { - return wrap(Numpy::AnyDtypeType::get(unwrap(context))); -} - -/*============================================================================*/ -/* Bool type. */ -/*============================================================================*/ - -int npcompTypeIsABool(MlirType t) { return unwrap(t).isa(); } - -MlirType npcompBoolTypeGet(MlirContext context) { - return wrap(Basicpy::BoolType::get(unwrap(context))); -} - -/*============================================================================*/ -/* Bytes type. */ -/*============================================================================*/ - -int npcompTypeIsABytes(MlirType t) { - return unwrap(t).isa(); -} - -MlirType npcompBytesTypeGet(MlirContext context) { - return wrap(Basicpy::BytesType::get(unwrap(context))); -} - -/*============================================================================*/ -/* Dict type. */ -/*============================================================================*/ - -/** Checks whether the given type is the Python "dict" type. */ -int npcompTypeIsADict(MlirType t) { return unwrap(t).isa(); } - -/** Gets the generic Python "dict" type. */ -MlirType npcompDictTypeGet(MlirContext context) { - return wrap(Basicpy::DictType::get(unwrap(context))); -} - -/*============================================================================*/ -/* List type. */ -/*============================================================================*/ - -/** Checks whether the given type is the Python "list" type. */ -int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa(); } - -/** Gets the generic Python "dict" type. */ -MlirType npcompBasicpyListTypeGet(MlirContext context) { - return wrap(Basicpy::ListType::get(unwrap(context))); -} - -/*============================================================================*/ -/* NDArray type. */ -/*============================================================================*/ - -int npcompTypeIsANdArray(MlirType t) { - return unwrap(t).isa(); -} - -MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) { - return wrap(Numpy::NdArrayType::get(unwrap(elementType))); -} - -MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape, - MlirType elementType) { - llvm::ArrayRef shapeArray(shape, rank); - return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray)); -} - -MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) { - return wrap(Numpy::NdArrayType::getFromShapedType( - unwrap(shapedType).cast())); -} - -MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType) { - return wrap(unwrap(ndarrayType).cast().toTensorType()); -} - -/*============================================================================*/ -/* !basicpy.NoneType type. */ -/*============================================================================*/ - -/** Checks whether the given type is a `!basicpy.NoneType`. */ -int npcompTypeIsANone(MlirType t) { return unwrap(t).isa(); } - -/** Gets the `!basicpy.NoneType` type. */ -MlirType npcompBasicpyNoneTypeGet(MlirContext context) { - return wrap(Basicpy::NoneType::get(unwrap(context))); -} - -/*============================================================================*/ -/* SlotObject type. */ -/*============================================================================*/ - -MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className, - intptr_t slotTypeCount, - const MlirType *slotTypes) { - MLIRContext *cppContext = unwrap(context); - auto classNameAttr = StringAttr::get(cppContext, unwrap(className)); - SmallVector slotTypesCpp; - slotTypesCpp.resize(slotTypeCount); - for (intptr_t i = 0; i < slotTypeCount; ++i) { - slotTypesCpp[i] = unwrap(slotTypes[i]); - } - return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp)); -} - -/*============================================================================*/ -/* Tuple type. */ -/*============================================================================*/ - -/** Checks whether the given type is the special "any dtype" type that is used - * to signal an NDArray or tensor of unknown type. */ -int npcompTypeIsATuple(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the "any dtype" type. */ -MlirType npcompTupleTypeGet(MlirContext context) { - return wrap(Basicpy::TupleType::get(unwrap(context))); -} - -/*============================================================================*/ -/* torch.nn.Module type. */ -/*============================================================================*/ - -/** Checks whether the given type is a torch.nn.Module type */ -int npcompTypeIsANnModule(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the torch.nn.Module type of the specified class. */ -MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) { - return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className))); -} - -/*============================================================================*/ -/* torch.optional type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.optional type */ -int npcompTypeIsAOptional(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the !torch.optional type with subtype T. */ -MlirType npcompOptionalTypeGet(MlirType containedType) { - return wrap(Torch::OptionalType::get(unwrap(containedType))); -} - -/*============================================================================*/ -/* torch.list type. */ -/*============================================================================*/ -/** Checks whether the given type is a !torch.list type */ -int npcompTypeIsAList(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the !torch.List type with contained T. */ -MlirType npcompListTypeGet(MlirType containedType) { - return wrap(Torch::ListType::get(unwrap(containedType))); -} - -/*============================================================================*/ -/* torch.Device type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.Device type */ -int npcompTypeIsADevice(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the !torch.Device type. */ -MlirType npcompDeviceTypeGet(MlirContext context) { - return wrap(Torch::DeviceType::get(unwrap(context))); -} - -/*============================================================================*/ -/* torch.LinearParams type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.LinearParams type */ -int npcompTypeIsALinearParams(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the !torch.LinearParams type. */ -MlirType npcompLinearParamsTypeGet(MlirContext context) { - return wrap(Torch::LinearParamsType::get(unwrap(context))); -} - -/*============================================================================*/ -/* torch.qint8 type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.qint8 type */ -int npcompTypeIsAQInt8(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the !torch.qint8 type. */ -MlirType npcompQInt8TypeGet(MlirContext context) { - return wrap(Torch::QInt8Type::get(unwrap(context))); -} - -/*============================================================================*/ -/* torch.tensor type. */ -/*============================================================================*/ - -int npcompTypeIsANonValueTensor(MlirType t) { - return unwrap(t).isa(); -} - -MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes, - const int64_t *optionalSizes, - MlirType optionalDtype) { - Optional> optionalSizesArrayRef = None; - if (optionalSizes) - optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); - return wrap(Torch::NonValueTensorType::get( - unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); -} - -MlirType -npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context) { - return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation( - unwrap(context))); -} - -MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type) { - return wrap(Torch::NonValueTensorType::getFromShaped( - unwrap(type).cast())); -} - -/*============================================================================*/ -/* torch.vtensor type. */ -/*============================================================================*/ - -int npcompTypeIsAValueTensor(MlirType t) { - return unwrap(t).isa(); -} - -MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes, - const int64_t *optionalSizes, - MlirType optionalDtype) { - Optional> optionalSizesArrayRef = None; - if (optionalSizes) - optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes); - return wrap(Torch::ValueTensorType::get( - unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype))); -} - -MlirType -npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context) { - return wrap( - Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context))); -} - -MlirType npcompValueTensorTypeGetFromShaped(MlirType type) { - return wrap( - Torch::ValueTensorType::getFromShaped(unwrap(type).cast())); -} - -/*============================================================================*/ -/* torch.none type. */ -/*============================================================================*/ - -/** Checks whether the given type is a !torch.none type */ -int npcompTypeIsATorchNone(MlirType t) { - return unwrap(t).isa(); -} - -/** Gets the !torch.none type. */ -MlirType npcompTorchNoneTypeGet(MlirContext context) { - return wrap(Torch::NoneType::get(unwrap(context))); -} diff --git a/python/NpcompModule.cpp b/python/NpcompModule.cpp index 46c1e3f96..db7281283 100644 --- a/python/NpcompModule.cpp +++ b/python/NpcompModule.cpp @@ -12,9 +12,10 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Diagnostics.h" +#include "npcomp-c/BasicpyTypes.h" #include "npcomp-c/InitLLVM.h" +#include "npcomp-c/NumpyTypes.h" #include "npcomp-c/Registration.h" -#include "npcomp-c/Types.h" #include "npcomp/Python/PybindUtils.h" #ifdef NPCOMP_ENABLE_REFJIT @@ -27,21 +28,21 @@ MlirType shapedToNdArrayArrayType(MlirType shaped_type) { if (!mlirTypeIsAShaped(shaped_type)) { throw py::raiseValueError("type is not a shaped type"); } - return npcompNdArrayTypeGetFromShaped(shaped_type); + return npcompNumpyNdArrayTypeGetFromShaped(shaped_type); } MlirType ndarrayToTensorType(MlirType ndarray_type) { - if (!npcompTypeIsANdArray(ndarray_type)) { + if (!npcompTypeIsANumpyNdArray(ndarray_type)) { throw py::raiseValueError("type is not an ndarray type"); } - return npcompNdArrayTypeToTensor(ndarray_type); + return npcompNumpyNdArrayTypeToTensor(ndarray_type); } MlirType slotObjectType(MlirContext context, const std::string &className, const std::vector &slotTypes) { MlirStringRef classNameSr{className.data(), className.size()}; - return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(), - slotTypes.data()); + return ::npcompBasicPySlotObjectTypeGet(context, classNameSr, + slotTypes.size(), slotTypes.data()); } // TODO: Move this upstream. diff --git a/test/CAPI/ir.c b/test/CAPI/ir.c index 29305239a..8bb9d0020 100644 --- a/test/CAPI/ir.c +++ b/test/CAPI/ir.c @@ -12,8 +12,9 @@ #include "mlir-c/IR.h" #include "mlir-c/Registration.h" +#include "npcomp-c/BasicpyTypes.h" +#include "npcomp-c/NumpyTypes.h" #include "npcomp-c/Registration.h" -#include "npcomp-c/Types.h" #include #include @@ -24,30 +25,31 @@ // Dumps an instance of all NPComp types. static int printStandardTypes(MlirContext ctx) { // Bool type. - MlirType boolType = npcompBoolTypeGet(ctx); - if (!npcompTypeIsABool(boolType)) + MlirType boolType = npcompBasicpyBoolTypeGet(ctx); + if (!npcompTypeIsABasicpyBool(boolType)) return 1; mlirTypeDump(boolType); fprintf(stderr, "\n"); // Bytes type. - MlirType bytesType = npcompBytesTypeGet(ctx); - if (!npcompTypeIsABytes(bytesType)) + MlirType bytesType = npcompBasicpyBytesTypeGet(ctx); + if (!npcompTypeIsABasicpyBytes(bytesType)) return 1; mlirTypeDump(bytesType); fprintf(stderr, "\n"); // Any dtype. MlirType anyDtype = npcompAnyDtypeTypeGet(ctx); - if (!npcompTypeIsAAnyDtype(anyDtype)) + if (!npcompTypeIsANumpyAnyDtype(anyDtype)) return 2; mlirTypeDump(anyDtype); fprintf(stderr, "\n"); // Ranked NdArray. int64_t fourDim = 4; - MlirType rankedNdArray = npcompNdArrayTypeGetRanked(1, &fourDim, boolType); - if (!npcompTypeIsANdArray(rankedNdArray)) + MlirType rankedNdArray = + npcompNumpyNdArrayTypeGetRanked(1, &fourDim, boolType); + if (!npcompTypeIsANumpyNdArray(rankedNdArray)) return 3; mlirTypeDump(rankedNdArray); fprintf(stderr, "\n");