mirror of https://github.com/llvm/torch-mlir

Make C API files more consistent

- Make consistent with MLIR Core - Use `//` or `///` comments.
- Use `bool` type for booleans
- No duplicated comments in .cpp files
- Split types into separate files `{Basicpy,Numpy,Torch}Types.h`
- Add dialect prefix consistently to C API symbols. We have lots of similarly named types (e.g. "list" type in basicpy and torch).

Branch: pull/222/head
Parent: db282fd1b4
Commit: 6b2424512b
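As a quick illustration of the new naming convention, here is a minimal sketch (not part of the commit) of caller code updated from the old unprefixed getters to the dialect-prefixed ones; the surrounding context (an already-created MlirContext and element type, with the npcomp dialects registered) is assumed:

#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/TorchTypes.h"

// Sketch: before this change, npcompListTypeGet() and npcompBoolTypeGet()
// did not say which dialect they belonged to. After it, the dialect is
// spelled out in every symbol name.
static void useRenamedGetters(MlirContext context, MlirType elementType) {
  MlirType torchList = npcompTorchListTypeGet(elementType);  // !torch.list<T>
  MlirType pyBool = npcompBasicpyBoolTypeGet(context);       // Python "bool"
  (void)torchList;
  (void)pyBool;
}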
@@ -11,7 +11,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/TorchTypes.h"
 #include "npcomp/Python/PybindUtils.h"

 #include <ATen/core/function_schema.h>

@@ -512,9 +512,8 @@ MlirType AcapController::mapIValueToMlirType(MlirLocation loc,
     return mlirIntegerTypeGet(funcBuilder->getContext(), 1);
   }
   if (ival.isList()) {
-    return npcompListTypeGet(
-        typeMapper.mapFromTorchType(
-            loc, ival.toList().elementType()));
+    return npcompTorchListTypeGet(
+        typeMapper.mapFromTorchType(loc, ival.toList().elementType()));
   }
   if (ival.isNone()) {
     return npcompTorchNoneTypeGet(funcBuilder->getContext());
@@ -530,7 +529,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
   MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
   MlirOperation tensorOp = createMlirOperationAtEnd(
       funcBuilder->getEntryBlock(), "torch.tensor", loc,
-      npcompNonValueTensorTypeGetFromShaped(
+      npcompTorchNonValueTensorTypeGetFromShaped(
           mlirAttributeGetType(denseElements)),
       toMlirNamedAttribute("value", denseElements));
   MlirValue tensorValue = mlirOperationGetResult(tensorOp, 0);
@@ -12,7 +12,7 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/TorchTypes.h"

 using namespace torch_mlir;

@@ -132,7 +132,7 @@ MlirValue FuncBuilder::getGeneralConstant(MlirLocation loc,

 MlirValue FuncBuilder::buildList(MlirLocation loc, MlirType elementType,
                                  std::vector<MlirValue> &elements) {
-  MlirType resultType = npcompListTypeGet(elementType);
+  MlirType resultType = npcompTorchListTypeGet(elementType);
   OperationStateHolder state{"torch.prim.ListConstruct", loc};
   mlirOperationStateAddResults(state, 1, &resultType);
   mlirOperationStateAddOperands(state, elements.size(), elements.data());
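For context, a hedged sketch of how the updated buildList helper might be invoked; the location, funcBuilder pointer and the two operand values (lhs, rhs) are assumed to already exist, and only buildList and npcompTorchListTypeGet come from the hunk above:

// Sketch only: buildList now produces a value of !torch.list<T> type via
// npcompTorchListTypeGet(); the operands are hypothetical MlirValues.
std::vector<MlirValue> elements = {lhs, rhs};
MlirType elementType = mlirIntegerTypeGet(funcBuilder->getContext(), 64);
MlirValue list = funcBuilder->buildList(loc, elementType, elements);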
@@ -16,7 +16,8 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/BasicpyTypes.h"
+#include "npcomp-c/TorchTypes.h"

 #include "caffe2/core/scope_guard.h"
 #include "ATen/native/quantized/cpu/packed_params.h"

@@ -170,7 +171,7 @@ IValueImporter::importModule(torch::jit::Module currentModule) {

   MlirOperation nnModule = createMlirOperation(
       "torch.nn_module", loc,
-      npcompNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
+      npcompTorchNnModuleTypeGet(context, toMlirStringRef(moduleTypeName)),
       mlirRegionCreate());
   MlirRegion nnModuleRegion = mlirOperationGetRegion(nnModule, 0);
   mlirRegionAppendOwnedBlock(nnModuleRegion, mlirBlockCreate(0, nullptr));

@@ -240,7 +241,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
   MlirLocation loc = mlirLocationUnknownGet(context);

   if (ivalue.isBool()) {
-    MlirType type = npcompBoolTypeGet(context);
+    MlirType type = npcompBasicpyBoolTypeGet(context);
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "basicpy.bool_constant", loc, type,
         toMlirNamedAttribute("value",

@@ -269,11 +270,11 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     for (const c10::IValue &elem : list) {
       elems.push_back(importIValue(elem));
     }
-    MlirOperation operation =
-        createMlirOperationAtEnd(importBlock, "torch.prim.ListConstruct", loc,
-                                 npcompListTypeGet(
-                                     typeMapper.mapFromTorchType(
-                                         loc, list.elementType())), elems);
+    MlirOperation operation = createMlirOperationAtEnd(
+        importBlock, "torch.prim.ListConstruct", loc,
+        npcompTorchListTypeGet(
+            typeMapper.mapFromTorchType(loc, list.elementType())),
+        elems);
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isTuple()) {

@@ -284,7 +285,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     }
     MlirOperation operation =
         createMlirOperationAtEnd(importBlock, "basicpy.build_tuple", loc,
-                                 npcompTupleTypeGet(context), elems);
+                                 npcompBasicpyTupleTypeGet(context), elems);
     return mlirOperationGetResult(operation, 0);
   }
   if (ivalue.isTensor()) {

@@ -294,7 +295,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     return importModule(ivalue.toModule());
   }
   if (ivalue.isString()) {
-    MlirType type = npcompBytesTypeGet(context);
+    MlirType type = npcompBasicpyBytesTypeGet(context);
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "basicpy.bytes_constant", loc, type,
         toMlirNamedAttribute(

@@ -325,7 +326,7 @@ MlirValue IValueImporter::rawImportIValue(c10::IValue ivalue) {
     }
     MlirOperation operation = createMlirOperationAtEnd(
         importBlock, "torch.linear_params.create", loc,
-        npcompLinearParamsTypeGet(context), weightValue, biasValue);
+        npcompTorchLinearParamsTypeGet(context), weightValue, biasValue);
     return mlirOperationGetResult(operation, 0);
   }
 }

@@ -345,7 +346,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
   MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
   MlirOperation tensorOp =
       createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
-                               npcompNonValueTensorTypeGetFromShaped(
+                               npcompTorchNonValueTensorTypeGetFromShaped(
                                    mlirAttributeGetType(denseElements)),
                                toMlirNamedAttribute("value", denseElements));
   MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);

@@ -360,7 +361,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
   // compiler stages that are building a statically modeled quantization
   // representation will need to convert this to their representation.
   std::vector<int64_t> shape(tensor.sizes().begin(), tensor.sizes().end());
-  MlirType quantizedTensorType = npcompNonValueTensorTypeGet(
+  MlirType quantizedTensorType = npcompTorchNonValueTensorTypeGet(
       context, shape.size(), shape.data(),
       typeMapper.mapFromTorchScalarType(tensor.scalar_type()));
   if (tensor.qscheme() == c10::kPerTensorAffine) {

@@ -504,11 +505,11 @@ void IValueImporter::importCompilationUnit(torch::jit::CompilationUnit *cu) {
             mlirLocationUnknownGet(context), *maybeDtype);
         MlirType typeBound;
         if (hasValueSemantics) {
-          typeBound = npcompValueTensorTypeGet(context, shape.size(),
-                                               shape.data(), dtype);
+          typeBound = npcompTorchValueTensorTypeGet(context, shape.size(),
+                                                    shape.data(), dtype);
         } else {
-          typeBound = npcompNonValueTensorTypeGet(context, shape.size(),
-                                                  shape.data(), dtype);
+          typeBound = npcompTorchNonValueTensorTypeGet(context, shape.size(),
+                                                       shape.data(), dtype);
         }

         MlirNamedAttribute typeBoundAttr = toMlirNamedAttribute(
@@ -15,7 +15,8 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/BasicpyTypes.h"
+#include "npcomp-c/TorchTypes.h"

 namespace py = pybind11;
 using namespace torch_mlir;

@@ -145,7 +146,8 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock) {
                             loc, appendToBlock),
         mlirRegionCreate());
     mapResults(node, operation);
-    std::vector<MlirType> terminatorOperandTypes = {npcompBoolTypeGet(context)};
+    std::vector<MlirType> terminatorOperandTypes = {
+        npcompBasicpyBoolTypeGet(context)};
     terminatorOperandTypes.insert(terminatorOperandTypes.end(),
                                   resultTypes.begin(), resultTypes.end());
     auto createTerminator = [&](c10::ArrayRef<MlirValue> yieldedValues,
@@ -10,7 +10,8 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/BasicpyTypes.h"
+#include "npcomp-c/TorchTypes.h"

 using namespace torch_mlir;

@@ -23,14 +24,14 @@ MlirOperation OpBuilder::createNoneConstant(MlirLocation loc) {

 MlirOperation OpBuilder::createBoolConstant(MlirLocation loc, bool value) {
   return createMlirOperation(
-      "basicpy.bool_constant", loc, npcompBoolTypeGet(context),
+      "basicpy.bool_constant", loc, npcompBasicpyBoolTypeGet(context),
       toMlirNamedAttribute("value", mlirBoolAttrGet(context, value)));
 }

 MlirOperation OpBuilder::createBytesConstant(MlirLocation loc,
                                              const std::string &value) {
   return createMlirOperation(
-      "basicpy.bytes_constant", loc, npcompBytesTypeGet(context),
+      "basicpy.bytes_constant", loc, npcompBasicpyBytesTypeGet(context),
       toMlirNamedAttribute("value",
                            mlirStringAttrGet(context, toMlirStringRef(value))));
 }
@@ -15,7 +15,8 @@
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Diagnostics.h"
-#include "npcomp-c/Types.h"
+#include "npcomp-c/BasicpyTypes.h"
+#include "npcomp-c/TorchTypes.h"

 using namespace torch_mlir;

@@ -54,7 +55,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
   case ScalarType::Long:
     return mlirIntegerTypeSignedGet(context, 64);
   case ScalarType::Bool:
-    return npcompBoolTypeGet(context);
+    return npcompBasicpyBoolTypeGet(context);
   case ScalarType::Double:
     return mlirF64TypeGet(context);
   case ScalarType::Float:

@@ -64,7 +65,7 @@ MlirType TypeMapper::rawMapFromTorchScalarType(c10::ScalarType scalarType) {
   case ScalarType::Half:
     return mlirF16TypeGet(context);
   case ScalarType::QInt8:
-    return npcompQInt8TypeGet(context);
+    return npcompTorchQInt8TypeGet(context);
   default: {
     return {nullptr};
   }

@@ -103,7 +104,7 @@ static MlirType mapCustomClassType(MlirContext context, MlirLocation loc,

   // Individually handle the custom classes that we know about.
   if (name == "__torch__.torch.classes.quantized.LinearPackedParamsBase") {
-    return npcompLinearParamsTypeGet(context);
+    return npcompTorchLinearParamsTypeGet(context);
   }

   // At this point, we know that the type is indeed a custom class type, but

@@ -134,11 +135,11 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
     auto &sizes = tensorType->symbolic_sizes();
     if (!sizes.rank()) {
       // Unranked.
-      return npcompNonValueTensorTypeGet(context,
+      return npcompTorchNonValueTensorTypeGet(context,
                                          /*numSizes=*/0,
                                          /*optionalSizes=*/nullptr,
                                          /*optionalDtype=*/
                                          elementType);
     }
     // Ranked with possibly dynamic dims.
     auto &symbolicShape = tensorType->symbolic_sizes();

@@ -148,10 +149,10 @@ MlirType TypeMapper::mapFromTorchType(MlirLocation loc,
       auto shapeSymbol = symbolicShape[i];
       dims[i] = shapeSymbol.is_static() ? shapeSymbol.static_size() : -1;
     }
-    return npcompNonValueTensorTypeGet(context, dims.size(),
+    return npcompTorchNonValueTensorTypeGet(context, dims.size(),
                                        /*optionalSizes=*/dims.data(),
                                        /*optionalDtype=*/
                                        elementType);
   }
   case TypeKind::ClassType: {
     const c10::ClassTypePtr &classType = torchType->cast<c10::ClassType>();

@@ -161,15 +162,14 @@
     }
     auto maybeName = classType->name();
     std::string name = maybeName ? maybeName->qualifiedName() : "unnamed class";
-    return npcompNnModuleTypeGet(context, toMlirStringRef(name));
+    return npcompTorchNnModuleTypeGet(context, toMlirStringRef(name));
   }
   case TypeKind::FloatType: {
     return mlirF64TypeGet(context);
   }
   case TypeKind::OptionalType: {
-    return npcompOptionalTypeGet(
-        mapFromTorchType(
-            loc, torchType->cast<c10::OptionalType>()->getElementType()));
+    return npcompTorchOptionalTypeGet(mapFromTorchType(
+        loc, torchType->cast<c10::OptionalType>()->getElementType()));
   }
   case TypeKind::IntType: {
     return mlirIntegerTypeGet(context, 64);

@@ -178,22 +178,21 @@
     return npcompTorchNoneTypeGet(context);
   }
   case TypeKind::BoolType: {
-    return npcompBoolTypeGet(context);
+    return npcompBasicpyBoolTypeGet(context);
   }
   case TypeKind::ListType: {
-    return npcompListTypeGet(
-        mapFromTorchType(
-            loc, torchType->cast<c10::ListType>()->getElementType()));
+    return npcompTorchListTypeGet(mapFromTorchType(
+        loc, torchType->cast<c10::ListType>()->getElementType()));
   }
   case TypeKind::TupleType: {
     // TODO: Don't lose the element type information.
-    return npcompTupleTypeGet(context);
+    return npcompBasicpyTupleTypeGet(context);
   }
   case TypeKind::StringType: {
-    return npcompBytesTypeGet(context);
+    return npcompBasicpyBytesTypeGet(context);
   }
   case TypeKind::DeviceObjType: {
-    return npcompDeviceTypeGet(context);
+    return npcompTorchDeviceTypeGet(context);
   }
   default: {
     std::stringstream message;

@@ -216,8 +215,8 @@ MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
   // just erase them and let the compiler decide.

   auto sizes = tensor.sizes();
-  return npcompNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
+  return npcompTorchNonValueTensorTypeGet(context, sizes.size(), sizes.data(),
                                      elementType);
 }

 MlirType
@@ -0,0 +1,92 @@
+//===-- npcomp-c/BasicpyTypes.h - C API for basicpy types ---------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_C_BASICPYTYPES_H
+#define NPCOMP_C_BASICPYTYPES_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// !basicpy.BoolType
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is the Python "bool" type.
+bool npcompTypeIsABasicpyBool(MlirType t);
+
+/// Gets the Python "bool" type.
+MlirType npcompBasicpyBoolTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !basicpy.BytesType
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is the Python "bytes" type.
+bool npcompTypeIsABasicpyBytes(MlirType t);
+
+/// Gets the Python "bytes" type.
+MlirType npcompBasicpyBytesTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !basicpy.DictType
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is the Python "dict" type.
+bool npcompTypeIsABasicpyDict(MlirType t);
+
+/// Gets the generic Python "dict" type.
+MlirType npcompBasicpyDictTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// List type
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is the Python "list" type.
+bool npcompTypeIsABasicpyList(MlirType t);
+
+/// Gets the generic Python "list" type.
+MlirType npcompBasicpyListTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// !basicpy.NoneType type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a `!basicpy.NoneType`.
+bool npcompTypeIsABasicpyNone(MlirType t);
+
+/// Gets the `!basicpy.NoneType` type.
+MlirType npcompBasicpyNoneTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// SlotObject type.
+//===----------------------------------------------------------------------===//
+
+MlirType npcompBasicPySlotObjectTypeGet(MlirContext context,
+                                        MlirStringRef className,
+                                        intptr_t slotTypeCount,
+                                        const MlirType *slotTypes);
+
+//===----------------------------------------------------------------------===//
+// !basicpy.TupleType
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a `!basicpy.TupleType`.
+bool npcompTypeIsABasicpyTuple(MlirType t);
+
+/// Gets the generic Python "tuple" type.
+MlirType npcompBasicpyTupleTypeGet(MlirContext context);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NPCOMP_C_BASICPYTYPES_H
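A minimal usage sketch of the header above (not part of the commit); it assumes an MlirContext with the npcomp dialects already registered:

#include "npcomp-c/BasicpyTypes.h"
#include <assert.h>

// Sketch: build a !basicpy.BoolType and verify it with the paired predicate.
static void checkBasicpyBool(MlirContext context) {
  MlirType boolType = npcompBasicpyBoolTypeGet(context);
  assert(npcompTypeIsABasicpyBool(boolType) && "expected a basicpy bool type");
  assert(!npcompTypeIsABasicpyBytes(boolType) && "bool is not bytes");
}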
@@ -0,0 +1,55 @@
+//===-- npcomp-c/NumpyTypes.h - C API for numpy types -------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_C_NUMPYTYPES_H
+#define NPCOMP_C_NUMPYTYPES_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// !numpy.any_dtype type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is the special "any dtype" type that is used
+/// to signal an NDArray or tensor of unknown type.
+bool npcompTypeIsANumpyAnyDtype(MlirType t);
+
+/// Gets the "any dtype" type.
+MlirType npcompAnyDtypeTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// NDArray type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is an NdArray type.
+bool npcompTypeIsANumpyNdArray(MlirType t);
+
+/// Gets a numpy.NdArray type that is unranked.
+MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType);
+
+/// Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
+/// unknown.
+MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
+                                         MlirType elementType);
+
+/// Helper that gets an equivalent NdArrayType from a ShapedType.
+MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType);
+
+/// Helper that converts an NdArrayType to a TensorType.
+MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NPCOMP_C_NUMPYTYPES_H
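For illustration, a hedged sketch of how the NdArray helpers above might be exercised; the f32 element type comes from the core MLIR C API and the specific shape is just an example:

#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/NumpyTypes.h"

// Sketch only: constructs a ranked 2x?xf32 NdArray type (-1 marks an unknown
// dim) and round-trips it through the ShapedType helpers.
static MlirType makeExampleNdArray(MlirContext context) {
  int64_t shape[2] = {2, -1};
  MlirType f32 = mlirF32TypeGet(context);
  MlirType ndarray = npcompNumpyNdArrayTypeGetRanked(2, shape, f32);
  MlirType asTensor = npcompNumpyNdArrayTypeToTensor(ndarray);
  return npcompNumpyNdArrayTypeGetFromShaped(asTensor);
}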
@@ -0,0 +1,143 @@
+//===-- npcomp-c/TorchTypes.h - C API for torch types -------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NPCOMP_C_TORCHTYPES_H
+#define NPCOMP_C_TORCHTYPES_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+//===----------------------------------------------------------------------===//
+// torch.nn.Module type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a torch.nn.Module type
+bool npcompTypeIsATorchNnModule(MlirType t);
+
+/// Gets the !torch.nn.Module type of the specified class.
+MlirType npcompTorchNnModuleTypeGet(MlirContext context,
+                                    MlirStringRef className);
+
+//===----------------------------------------------------------------------===//
+// torch.optional type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.optional<T> type
+bool npcompTypeIsATorchOptional(MlirType t);
+
+/// Gets the !torch.optional<T> type with subtype T.
+MlirType npcompTorchOptionalTypeGet(MlirType containedType);
+
+//===----------------------------------------------------------------------===//
+// torch.list<T> type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.list<T> type
+bool npcompTypeIsATorchList(MlirType t);
+
+/// Gets the !torch.list<T> type with contained T.
+MlirType npcompTorchListTypeGet(MlirType containedType);
+
+//===----------------------------------------------------------------------===//
+// torch.Device type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.Device type
+bool npcompTypeIsATorchDevice(MlirType t);
+
+/// Gets the !torch.Device type.
+MlirType npcompTorchDeviceTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.LinearParams type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.LinearParams type
+bool npcompTypeIsATorchLinearParams(MlirType t);
+
+/// Gets the !torch.LinearParams type.
+MlirType npcompTorchLinearParamsTypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.qint8 type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.qint8 type
+bool npcompTypeIsATorchQInt8(MlirType t);
+
+/// Gets the !torch.qint8 type.
+MlirType npcompTorchQInt8TypeGet(MlirContext context);
+
+//===----------------------------------------------------------------------===//
+// torch.tensor type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.tensor type
+bool npcompTypeIsATorchNonValueTensor(MlirType t);
+
+/// Gets a !torch.tensor type.
+///
+/// - `optionalSizes` is allowed to be null, meaning that no size
+///   information is present (and `numSizes` is ignored in that case).
+/// - `optionalDtype` is allowed to be null, meaning that no dtype
+///   information is present.
+MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
+                                          intptr_t numSizes,
+                                          const int64_t *optionalSizes,
+                                          MlirType optionalDtype);
+
+/// Gets the !torch.tensor type with the least static information.
+MlirType
+npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
+
+/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
+MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type);
+
+//===----------------------------------------------------------------------===//
+// torch.vtensor type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.vtensor type
+bool npcompTypeIsATorchValueTensor(MlirType t);
+
+/// Gets a !torch.vtensor type.
+///
+/// - `optionalSizes` is allowed to be null, meaning that no size
+///   information is present (and `numSizes` is ignored in that case).
+/// - `optionalDtype` is allowed to be null, meaning that no dtype
+///   information is present.
+MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
+                                       const int64_t *optionalSizes,
+                                       MlirType optionalDtype);
+
+/// Gets the !torch.tensor type with the least static information.
+MlirType
+npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
+
+/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
+MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type);
+
+//===----------------------------------------------------------------------===//
+// !torch.none type.
+//===----------------------------------------------------------------------===//
+
+/// Checks whether the given type is a !torch.none type
+bool npcompTypeIsATorchNone(MlirType t);
+
+/// Gets the !torch.none type.
+MlirType npcompTorchNoneTypeGet(MlirContext context);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // NPCOMP_C_TORCHTYPES_H
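A brief usage sketch of the tensor-type getters declared above (not part of the commit); the sizes and element type are arbitrary examples, and the context is assumed to have the torch dialect registered:

#include "mlir-c/BuiltinTypes.h"
#include "npcomp-c/TorchTypes.h"

// Sketch only: builds a ranked !torch.vtensor with a dynamic second dim and a
// !torch.tensor carrying no size/dtype information at all.
static void makeExampleTorchTensorTypes(MlirContext context) {
  int64_t sizes[2] = {2, -1};
  MlirType f32 = mlirF32TypeGet(context);
  MlirType vtensor = npcompTorchValueTensorTypeGet(context, 2, sizes, f32);
  MlirType tensor =
      npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(context);
  (void)vtensor;
  (void)tensor;
}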
@@ -1,246 +0,0 @@
-/*===-- npcomp-c/Types.h - NPComp custom types --------------------*- C -*-===*\
-|*                                                                            *|
-|* Part of the LLVM Project, under the Apache License v2.0 with LLVM          *|
-|* Exceptions.                                                                *|
-|* See https://llvm.org/LICENSE.txt for license information.                  *|
-|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception                    *|
-|*                                                                            *|
-\*===----------------------------------------------------------------------===*/
-
-#ifndef NPCOMP_C_TYPES_H
-#define NPCOMP_C_TYPES_H
-
-#include "mlir-c/IR.h"
-
-#ifdef __cplusplus
-extern "C" {
-#endif
-
-/*============================================================================*/
-/* Any dtype type.                                                            */
-/*============================================================================*/
-
-/** Checks whether the given type is the special "any dtype" type that is used
- * to signal an NDArray or tensor of unknown type. */
-int npcompTypeIsAAnyDtype(MlirType t);
-
-/** Gets the "any dtype" type. */
-MlirType npcompAnyDtypeTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* Bool type.                                                                 */
-/*============================================================================*/
-
-/** Checks whether the given type is the Python "bool" type. */
-int npcompTypeIsABool(MlirType t);
-
-/** Gets the Python "bool" type. */
-MlirType npcompBoolTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* Bytes type.                                                                */
-/*============================================================================*/
-
-/** Checks whether the given type is the Python "bytes" type. */
-int npcompTypeIsABytes(MlirType t);
-
-/** Gets the Python "bytes" type. */
-MlirType npcompBytesTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* Dict type.                                                                 */
-/*============================================================================*/
-
-/** Checks whether the given type is the Python "dict" type. */
-int npcompTypeIsADict(MlirType t);
-
-/** Gets the generic Python "dict" type. */
-MlirType npcompDictTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* List type.                                                                 */
-/*============================================================================*/
-
-/** Checks whether the given type is the Python "list" type. */
-int npcompTypeIsABasicpyList(MlirType t);
-
-/** Gets the generic Python "list" type. */
-MlirType npcompBasicpyListTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* NDArray type.                                                              */
-/*============================================================================*/
-
-/** Checks whether the given type is an NdArray type. */
-int npcompTypeIsANdArray(MlirType t);
-
-/** Gets a numpy.NdArray type that is unranked. */
-MlirType npcompNdArrayTypeGetUnranked(MlirType elementType);
-
-/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
- * unknown. */
-MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
-                                    MlirType elementType);
-
-/// Helper that gets an equivalent NdArrayType from a ShapedType.
-MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType);
-
-/// Helper that converts an NdArrayType to a TensorType.
-MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType);
-
-/*============================================================================*/
-/* !basicpy.NoneType type.                                                    */
-/*============================================================================*/
-
-/** Checks whether the given type is a `!basicpy.NoneType`. */
-int npcompTypeIsABasicpyNone(MlirType t);
-
-/** Gets the `!basicpy.NoneType` type. */
-MlirType npcompBasicpyNoneTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* SlotObject type.                                                           */
-/*============================================================================*/
-
-MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
-                                 intptr_t slotTypeCount,
-                                 const MlirType *slotTypes);
-
-/*============================================================================*/
-/* Tuple type.                                                                */
-/*============================================================================*/
-
-/** Checks whether the given type is the special "any dtype" type that is used
- * to signal an NDArray or tensor of unknown type. */
-int npcompTypeIsATuple(MlirType t);
-
-/** Gets the generic Python "tuple" type. */
-MlirType npcompTupleTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* torch.nn.Module type.                                                      */
-/*============================================================================*/
-
-/** Checks whether the given type is a torch.nn.Module type */
-int npcompTypeIsANnModule(MlirType t);
-
-/** Gets the !torch.nn.Module type of the specified class. */
-MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className);
-
-/*============================================================================*/
-/* torch.optional type.                                                       */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.optional<T> type */
-int npcompTypeIsAOptional(MlirType t);
-
-/** Gets the !torch.optional<T> type with subtype T. */
-MlirType npcompOptionalTypeGet(MlirType containedType);
-
-/*============================================================================*/
-/* torch.list type.                                                           */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.list<T> type */
-int npcompTypeIsAList(MlirType t);
-
-/** Gets the !torch.list<T> type with contained T. */
-MlirType npcompListTypeGet(MlirType containedType);
-
-/*============================================================================*/
-/* torch.Device type.                                                         */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.Device type */
-int npcompTypeIsADevice(MlirType t);
-
-/** Gets the !torch.Device type. */
-MlirType npcompDeviceTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* torch.LinearParams type.                                                   */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.LinearParams type */
-int npcompTypeIsALinearParams(MlirType t);
-
-/** Gets the !torch.LinearParams type. */
-MlirType npcompLinearParamsTypeGet(MlirContext context);
-
-/*============================================================================*/
-/* torch.qint8 type.                                                          */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.qint8 type */
-int npcompTypeIsAQInt8(MlirType t);
-
-/** Gets the !torch.qint8 type. */
-MlirType npcompQInt8TypeGet(MlirContext context);
-
-/*============================================================================*/
-/* torch.tensor type.                                                         */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.tensor type */
-int npcompTypeIsANonValueTensor(MlirType t);
-
-/** Gets a !torch.tensor type.
- *
- * - `optionalSizes` is allowed to be null, meaning that no size information is
- *   present (and `numSizes` is ignored in that case).
- * - `optionalDtype` is allowed to be null, meaning that no dtype information is
- *   present.
- *
- */
-MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
-                                     const int64_t *optionalSizes,
-                                     MlirType optionalDtype);
-
-/** Gets the !torch.tensor type with the least static information. */
-MlirType
-npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
-
-/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
-MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type);
-
-/*============================================================================*/
-/* torch.vtensor type.                                                        */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.vtensor type */
-int npcompTypeIsAValueTensor(MlirType t);
-
-/** Gets a !torch.vtensor type.
- *
- * - `optionalSizes` is allowed to be null, meaning that no size information is
- *   present (and `numSizes` is ignored in that case).
- * - `optionalDtype` is allowed to be null, meaning that no dtype information is
- *   present.
- *
- */
-MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
-                                  const int64_t *optionalSizes,
-                                  MlirType optionalDtype);
-
-/** Gets the !torch.tensor type with the least static information. */
-MlirType
-npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
-
-/** Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`. */
-MlirType npcompValueTensorTypeGetFromShaped(MlirType type);
-
-/*============================================================================*/
-/* !torch.none type.                                                          */
-/*============================================================================*/
-
-/** Checks whether the given type is a !torch.none type */
-int npcompTypeIsATorchNone(MlirType t);
-
-/** Gets the !torch.none type. */
-MlirType npcompTorchNoneTypeGet(MlirContext context);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif // NPCOMP_C_TYPES_H
@@ -0,0 +1,116 @@
+//===- BasicpyTypes.cpp - C Interface for basicpy types -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "npcomp-c/BasicpyTypes.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+
+//===----------------------------------------------------------------------===//
+// Bool type.
+//===----------------------------------------------------------------------===//
+
+bool npcompTypeIsABasicpyBool(MlirType t) {
+  return unwrap(t).isa<Basicpy::BoolType>();
+}
+
+MlirType npcompBasicpyBoolTypeGet(MlirContext context) {
+  return wrap(Basicpy::BoolType::get(unwrap(context)));
+}
+
+//===----------------------------------------------------------------------===//
+// Bytes type.
+//===----------------------------------------------------------------------===//
+
+bool npcompTypeIsABasicpyBytes(MlirType t) {
+  return unwrap(t).isa<Basicpy::BytesType>();
+}
+
+MlirType npcompBasicpyBytesTypeGet(MlirContext context) {
+  return wrap(Basicpy::BytesType::get(unwrap(context)));
+}
+
+//===----------------------------------------------------------------------===//
+// Dict type.
+//===----------------------------------------------------------------------===//
+
+/** Checks whether the given type is the Python "dict" type. */
+bool npcompTypeIsABasicpyDict(MlirType t) {
+  return unwrap(t).isa<Basicpy::DictType>();
+}
+
+/** Gets the generic Python "dict" type. */
+MlirType npcompBasicpyDictTypeGet(MlirContext context) {
+  return wrap(Basicpy::DictType::get(unwrap(context)));
+}
+
+//===----------------------------------------------------------------------===//
+// List type.
+//===----------------------------------------------------------------------===//
+
+/** Checks whether the given type is the Python "list" type. */
+bool npcompTypeIsABasicpyList(MlirType t) {
+  return unwrap(t).isa<Basicpy::ListType>();
+}
+
+/** Gets the generic Python "dict" type. */
+MlirType npcompBasicpyListTypeGet(MlirContext context) {
+  return wrap(Basicpy::ListType::get(unwrap(context)));
+}
+
+//===----------------------------------------------------------------------===//
+// !basicpy.NoneType type.
+//===----------------------------------------------------------------------===//
+
+/** Checks whether the given type is a `!basicpy.NoneType`. */
+bool npcompTypeIsANone(MlirType t) {
+  return unwrap(t).isa<Basicpy::NoneType>();
+}
+
+/** Gets the `!basicpy.NoneType` type. */
+MlirType npcompBasicpyNoneTypeGet(MlirContext context) {
+  return wrap(Basicpy::NoneType::get(unwrap(context)));
+}
+
+//===----------------------------------------------------------------------===//
+// SlotObject type.
+//===----------------------------------------------------------------------===//
+
+MlirType npcompBasicPySlotObjectTypeGet(MlirContext context,
+                                        MlirStringRef className,
+                                        intptr_t slotTypeCount,
+                                        const MlirType *slotTypes) {
+  MLIRContext *cppContext = unwrap(context);
+  auto classNameAttr = StringAttr::get(cppContext, unwrap(className));
+  SmallVector<Type> slotTypesCpp;
+  slotTypesCpp.resize(slotTypeCount);
+  for (intptr_t i = 0; i < slotTypeCount; ++i) {
+    slotTypesCpp[i] = unwrap(slotTypes[i]);
+  }
+  return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
+}
+
+//===----------------------------------------------------------------------===//
+// Tuple type.
+//===----------------------------------------------------------------------===//
+
+/** Checks whether the given type is the special "any dtype" type that is used
+ * to signal an NDArray or tensor of unknown type. */
+bool npcompTypeIsABasicpyTuple(MlirType t) {
+  return unwrap(t).isa<Basicpy::TupleType>();
+}
+
+/** Gets the "any dtype" type. */
+MlirType npcompBasicpyTupleTypeGet(MlirContext context) {
+  return wrap(Basicpy::TupleType::get(unwrap(context)));
+}
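All of these .cpp bridges follow MLIR's standard C API pattern: unwrap the opaque C handle into the underlying C++ mlir::Type, do the work against the dialect's C++ API, then wrap the result back into a C handle. As a hypothetical sketch of adding one more getter in the same style (Basicpy::ExampleType does not exist and only stands in for a real dialect type):

// Hypothetical example of the wrap/unwrap pattern used throughout these files.
bool npcompTypeIsABasicpyExample(MlirType t) {
  // unwrap() turns the opaque C handle back into the C++ mlir::Type.
  return unwrap(t).isa<Basicpy::ExampleType>();
}

MlirType npcompBasicpyExampleTypeGet(MlirContext context) {
  // wrap() re-packages the C++ type as an opaque C handle for the caller.
  return wrap(Basicpy::ExampleType::get(unwrap(context)));
}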
@@ -7,7 +7,9 @@ set(LLVM_LINK_COMPONENTS
 add_npcomp_library(NPCOMPCAPI
   InitLLVM.cpp
   Registration.cpp
-  Types.cpp
+  BasicpyTypes.cpp
+  NumpyTypes.cpp
+  TorchTypes.cpp

   LINK_LIBS PUBLIC
   MLIRExecutionEngine
@@ -0,0 +1,56 @@
+//===- NumpyTypes.cpp - C Interface for numpy types -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "npcomp-c/NumpyTypes.h"
+
+#include "mlir/CAPI/IR.h"
+#include "mlir/CAPI/Support.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
+
+using namespace mlir;
+using namespace mlir::NPCOMP;
+
+//===----------------------------------------------------------------------===//
+// !numpy.any_dtype type.
+//===----------------------------------------------------------------------===//
+
+bool npcompTypeIsANumpyAnyDtype(MlirType t) {
+  return unwrap(t).isa<Numpy::AnyDtypeType>();
+}
+
+MlirType npcompAnyDtypeTypeGet(MlirContext context) {
+  return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
+}
+
+//===----------------------------------------------------------------------===//
+// NDArray type.
+//===----------------------------------------------------------------------===//
+
+bool npcompTypeIsANumpyNdArray(MlirType t) {
+  return unwrap(t).isa<Numpy::NdArrayType>();
+}
+
+MlirType npcompNumpyNdArrayTypeGetUnranked(MlirType elementType) {
+  return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
+}
+
+MlirType npcompNumpyNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
+                                         MlirType elementType) {
+  llvm::ArrayRef<int64_t> shapeArray(shape, rank);
+  return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
+}
+
+MlirType npcompNumpyNdArrayTypeGetFromShaped(MlirType shapedType) {
+  return wrap(Numpy::NdArrayType::getFromShapedType(
+      unwrap(shapedType).cast<ShapedType>()));
+}
+
+MlirType npcompNumpyNdArrayTypeToTensor(MlirType ndarrayType) {
+  return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
+}
@ -0,0 +1,165 @@
//===- TorchTypes.cpp - C Interface for torch types -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp-c/TorchTypes.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::NPCOMP;

//===----------------------------------------------------------------------===//
// torch.nn.Module type.
//===----------------------------------------------------------------------===//

/** Checks whether the given type is a torch.nn.Module type */
bool npcompTypeIsATorchNnModule(MlirType t) {
  return unwrap(t).isa<Torch::NnModuleType>();
}

/** Gets the torch.nn.Module type of the specified class. */
MlirType npcompTorchNnModuleTypeGet(MlirContext context,
                                    MlirStringRef className) {
  return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}

//===----------------------------------------------------------------------===//
// torch.optional type.
//===----------------------------------------------------------------------===//

/** Checks whether the given type is a !torch.optional<T> type */
bool npcompTypeIsATorchOptional(MlirType t) {
  return unwrap(t).isa<Torch::OptionalType>();
}

/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompTorchOptionalTypeGet(MlirType containedType) {
  return wrap(Torch::OptionalType::get(unwrap(containedType)));
}

//===----------------------------------------------------------------------===//
// torch.list<T> type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchList(MlirType t) {
  return unwrap(t).isa<Torch::ListType>();
}

MlirType npcompTorchListTypeGet(MlirType containedType) {
  return wrap(Torch::ListType::get(unwrap(containedType)));
}

//===----------------------------------------------------------------------===//
// torch.Device type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchDevice(MlirType t) {
  return unwrap(t).isa<Torch::DeviceType>();
}

MlirType npcompTorchDeviceTypeGet(MlirContext context) {
  return wrap(Torch::DeviceType::get(unwrap(context)));
}

//===----------------------------------------------------------------------===//
// torch.LinearParams type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchLinearParams(MlirType t) {
  return unwrap(t).isa<Torch::LinearParamsType>();
}

MlirType npcompTorchLinearParamsTypeGet(MlirContext context) {
  return wrap(Torch::LinearParamsType::get(unwrap(context)));
}

//===----------------------------------------------------------------------===//
// torch.qint8 type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchQInt8(MlirType t) {
  return unwrap(t).isa<Torch::QInt8Type>();
}

MlirType npcompTorchQInt8TypeGet(MlirContext context) {
  return wrap(Torch::QInt8Type::get(unwrap(context)));
}

//===----------------------------------------------------------------------===//
// torch.tensor type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchNonValueTensor(MlirType t) {
  return unwrap(t).isa<Torch::NonValueTensorType>();
}

MlirType npcompTorchNonValueTensorTypeGet(MlirContext context,
                                          intptr_t numSizes,
                                          const int64_t *optionalSizes,
                                          MlirType optionalDtype) {
  Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
  if (optionalSizes)
    optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
  return wrap(Torch::NonValueTensorType::get(
      unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

MlirType npcompTorchNonValueTensorTypeGetWithLeastStaticInformation(
    MlirContext context) {
  return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
      unwrap(context)));
}

MlirType npcompTorchNonValueTensorTypeGetFromShaped(MlirType type) {
  return wrap(Torch::NonValueTensorType::getFromShaped(
      unwrap(type).cast<ShapedType>()));
}

//===----------------------------------------------------------------------===//
// torch.vtensor type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchValueTensor(MlirType t) {
  return unwrap(t).isa<Torch::ValueTensorType>();
}

MlirType npcompTorchValueTensorTypeGet(MlirContext context, intptr_t numSizes,
                                       const int64_t *optionalSizes,
                                       MlirType optionalDtype) {
  Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
  if (optionalSizes)
    optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
  return wrap(Torch::ValueTensorType::get(
      unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

MlirType
npcompTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
  return wrap(
      Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}

MlirType npcompTorchValueTensorTypeGetFromShaped(MlirType type) {
  return wrap(
      Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}

//===----------------------------------------------------------------------===//
// torch.none type.
//===----------------------------------------------------------------------===//

bool npcompTypeIsATorchNone(MlirType t) {
  return unwrap(t).isa<Torch::NoneType>();
}

MlirType npcompTorchNoneTypeGet(MlirContext context) {
  return wrap(Torch::NoneType::get(unwrap(context)));
}
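(Illustration only, not part of this commit: a sketch of how the dialect-prefixed torch type constructors defined above compose, again assuming a context with the torch dialect registered. The helper name, sizes, and dtype are arbitrary example values; every API call used is defined in this file or in the standard MLIR C API.)

#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/IR.h"
#include "npcomp-c/TorchTypes.h"

#include <assert.h>

/* Assumes `ctx` has the torch dialect registered; sizes/dtype are arbitrary. */
static void buildTorchTypes(MlirContext ctx) {
  /* A value-semantic tensor type with static sizes [2, 3] and f32 dtype.
   * Passing a null `optionalSizes` pointer would instead leave the sizes
   * unspecified, per the implementation above. */
  int64_t sizes[] = {2, 3};
  MlirType vtensor =
      npcompTorchValueTensorTypeGet(ctx, 2, sizes, mlirF32TypeGet(ctx));
  assert(npcompTypeIsATorchValueTensor(vtensor));

  /* Container and singleton types compose directly. */
  MlirType list = npcompTorchListTypeGet(vtensor);
  MlirType optional = npcompTorchOptionalTypeGet(vtensor);
  MlirType none = npcompTorchNoneTypeGet(ctx);
  assert(npcompTypeIsATorchList(list));
  assert(npcompTypeIsATorchOptional(optional));
  assert(npcompTypeIsATorchNone(none));
  mlirTypeDump(list);
}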
@ -1,303 +0,0 @@
//===- Types.cpp - C Interface for NPComp types ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "npcomp-c/Types.h"

#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/IR/BuiltinTypes.h"
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
#include "npcomp/Dialect/Torch/IR/TorchTypes.h"

using namespace mlir;
using namespace mlir::NPCOMP;

/*============================================================================*/
/* Any dtype type. */
/*============================================================================*/

int npcompTypeIsAAnyDtype(MlirType t) {
  return unwrap(t).isa<Numpy::AnyDtypeType>();
}

MlirType npcompAnyDtypeTypeGet(MlirContext context) {
  return wrap(Numpy::AnyDtypeType::get(unwrap(context)));
}

/*============================================================================*/
/* Bool type. */
/*============================================================================*/

int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<Basicpy::BoolType>(); }

MlirType npcompBoolTypeGet(MlirContext context) {
  return wrap(Basicpy::BoolType::get(unwrap(context)));
}

/*============================================================================*/
/* Bytes type. */
/*============================================================================*/

int npcompTypeIsABytes(MlirType t) {
  return unwrap(t).isa<Basicpy::BytesType>();
}

MlirType npcompBytesTypeGet(MlirContext context) {
  return wrap(Basicpy::BytesType::get(unwrap(context)));
}

/*============================================================================*/
/* Dict type. */
/*============================================================================*/

/** Checks whether the given type is the Python "dict" type. */
int npcompTypeIsADict(MlirType t) { return unwrap(t).isa<Basicpy::DictType>(); }

/** Gets the generic Python "dict" type. */
MlirType npcompDictTypeGet(MlirContext context) {
  return wrap(Basicpy::DictType::get(unwrap(context)));
}

/*============================================================================*/
/* List type. */
/*============================================================================*/

/** Checks whether the given type is the Python "list" type. */
int npcompTypeIsABasicpyList(MlirType t) { return unwrap(t).isa<Basicpy::ListType>(); }

/** Gets the generic Python "dict" type. */
MlirType npcompBasicpyListTypeGet(MlirContext context) {
  return wrap(Basicpy::ListType::get(unwrap(context)));
}

/*============================================================================*/
/* NDArray type. */
/*============================================================================*/

int npcompTypeIsANdArray(MlirType t) {
  return unwrap(t).isa<Numpy::NdArrayType>();
}

MlirType npcompNdArrayTypeGetUnranked(MlirType elementType) {
  return wrap(Numpy::NdArrayType::get(unwrap(elementType)));
}

MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
                                    MlirType elementType) {
  llvm::ArrayRef<int64_t> shapeArray(shape, rank);
  return wrap(Numpy::NdArrayType::get(unwrap(elementType), shapeArray));
}

MlirType npcompNdArrayTypeGetFromShaped(MlirType shapedType) {
  return wrap(Numpy::NdArrayType::getFromShapedType(
      unwrap(shapedType).cast<ShapedType>()));
}

MlirType npcompNdArrayTypeToTensor(MlirType ndarrayType) {
  return wrap(unwrap(ndarrayType).cast<Numpy::NdArrayType>().toTensorType());
}

/*============================================================================*/
/* !basicpy.NoneType type. */
/*============================================================================*/

/** Checks whether the given type is a `!basicpy.NoneType`. */
int npcompTypeIsANone(MlirType t) { return unwrap(t).isa<Basicpy::NoneType>(); }

/** Gets the `!basicpy.NoneType` type. */
MlirType npcompBasicpyNoneTypeGet(MlirContext context) {
  return wrap(Basicpy::NoneType::get(unwrap(context)));
}

/*============================================================================*/
/* SlotObject type. */
/*============================================================================*/

MlirType npcompSlotObjectTypeGet(MlirContext context, MlirStringRef className,
                                 intptr_t slotTypeCount,
                                 const MlirType *slotTypes) {
  MLIRContext *cppContext = unwrap(context);
  auto classNameAttr = StringAttr::get(cppContext, unwrap(className));
  SmallVector<Type> slotTypesCpp;
  slotTypesCpp.resize(slotTypeCount);
  for (intptr_t i = 0; i < slotTypeCount; ++i) {
    slotTypesCpp[i] = unwrap(slotTypes[i]);
  }
  return wrap(Basicpy::SlotObjectType::get(classNameAttr, slotTypesCpp));
}

/*============================================================================*/
/* Tuple type. */
/*============================================================================*/

/** Checks whether the given type is the special "any dtype" type that is used
 * to signal an NDArray or tensor of unknown type. */
int npcompTypeIsATuple(MlirType t) {
  return unwrap(t).isa<Basicpy::TupleType>();
}

/** Gets the "any dtype" type. */
MlirType npcompTupleTypeGet(MlirContext context) {
  return wrap(Basicpy::TupleType::get(unwrap(context)));
}

/*============================================================================*/
/* torch.nn.Module type. */
/*============================================================================*/

/** Checks whether the given type is a torch.nn.Module type */
int npcompTypeIsANnModule(MlirType t) {
  return unwrap(t).isa<Torch::NnModuleType>();
}

/** Gets the torch.nn.Module type of the specified class. */
MlirType npcompNnModuleTypeGet(MlirContext context, MlirStringRef className) {
  return wrap(Torch::NnModuleType::get(unwrap(context), unwrap(className)));
}

/*============================================================================*/
/* torch.optional type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.optional<T> type */
int npcompTypeIsAOptional(MlirType t) {
  return unwrap(t).isa<Torch::OptionalType>();
}

/** Gets the !torch.optional<T> type with subtype T. */
MlirType npcompOptionalTypeGet(MlirType containedType) {
  return wrap(Torch::OptionalType::get(unwrap(containedType)));
}

/*============================================================================*/
/* torch.list type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.list<T> type */
int npcompTypeIsAList(MlirType t) {
  return unwrap(t).isa<Torch::ListType>();
}

/** Gets the !torch.List<T> type with contained T. */
MlirType npcompListTypeGet(MlirType containedType) {
  return wrap(Torch::ListType::get(unwrap(containedType)));
}

/*============================================================================*/
/* torch.Device type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.Device type */
int npcompTypeIsADevice(MlirType t) {
  return unwrap(t).isa<Torch::DeviceType>();
}

/** Gets the !torch.Device type. */
MlirType npcompDeviceTypeGet(MlirContext context) {
  return wrap(Torch::DeviceType::get(unwrap(context)));
}

/*============================================================================*/
/* torch.LinearParams type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.LinearParams type */
int npcompTypeIsALinearParams(MlirType t) {
  return unwrap(t).isa<Torch::LinearParamsType>();
}

/** Gets the !torch.LinearParams type. */
MlirType npcompLinearParamsTypeGet(MlirContext context) {
  return wrap(Torch::LinearParamsType::get(unwrap(context)));
}

/*============================================================================*/
/* torch.qint8 type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.qint8 type */
int npcompTypeIsAQInt8(MlirType t) {
  return unwrap(t).isa<Torch::QInt8Type>();
}

/** Gets the !torch.qint8 type. */
MlirType npcompQInt8TypeGet(MlirContext context) {
  return wrap(Torch::QInt8Type::get(unwrap(context)));
}

/*============================================================================*/
/* torch.tensor type. */
/*============================================================================*/

int npcompTypeIsANonValueTensor(MlirType t) {
  return unwrap(t).isa<Torch::NonValueTensorType>();
}

MlirType npcompNonValueTensorTypeGet(MlirContext context, intptr_t numSizes,
                                     const int64_t *optionalSizes,
                                     MlirType optionalDtype) {
  Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
  if (optionalSizes)
    optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
  return wrap(Torch::NonValueTensorType::get(
      unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

MlirType
npcompNonValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
  return wrap(Torch::NonValueTensorType::getWithLeastStaticInformation(
      unwrap(context)));
}

MlirType npcompNonValueTensorTypeGetFromShaped(MlirType type) {
  return wrap(Torch::NonValueTensorType::getFromShaped(
      unwrap(type).cast<ShapedType>()));
}

/*============================================================================*/
/* torch.vtensor type. */
/*============================================================================*/

int npcompTypeIsAValueTensor(MlirType t) {
  return unwrap(t).isa<Torch::ValueTensorType>();
}

MlirType npcompValueTensorTypeGet(MlirContext context, intptr_t numSizes,
                                  const int64_t *optionalSizes,
                                  MlirType optionalDtype) {
  Optional<ArrayRef<int64_t>> optionalSizesArrayRef = None;
  if (optionalSizes)
    optionalSizesArrayRef = llvm::makeArrayRef(optionalSizes, numSizes);
  return wrap(Torch::ValueTensorType::get(
      unwrap(context), optionalSizesArrayRef, unwrap(optionalDtype)));
}

MlirType
npcompValueTensorTypeGetWithLeastStaticInformation(MlirContext context) {
  return wrap(
      Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}

MlirType npcompValueTensorTypeGetFromShaped(MlirType type) {
  return wrap(
      Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}

/*============================================================================*/
/* torch.none type. */
/*============================================================================*/

/** Checks whether the given type is a !torch.none type */
int npcompTypeIsATorchNone(MlirType t) {
  return unwrap(t).isa<Torch::NoneType>();
}

/** Gets the !torch.none type. */
MlirType npcompTorchNoneTypeGet(MlirContext context) {
  return wrap(Torch::NoneType::get(unwrap(context)));
}
@ -12,9 +12,10 @@
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinAttributes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/BuiltinTypes.h"
#include "mlir-c/Diagnostics.h"
#include "mlir-c/Diagnostics.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp-c/NumpyTypes.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"
#include "npcomp/Python/PybindUtils.h"
#include "npcomp/Python/PybindUtils.h"

#ifdef NPCOMP_ENABLE_REFJIT
#ifdef NPCOMP_ENABLE_REFJIT
@ -27,21 +28,21 @@ MlirType shapedToNdArrayArrayType(MlirType shaped_type) {
  if (!mlirTypeIsAShaped(shaped_type)) {
  if (!mlirTypeIsAShaped(shaped_type)) {
    throw py::raiseValueError("type is not a shaped type");
    throw py::raiseValueError("type is not a shaped type");
  }
  }
  return npcompNdArrayTypeGetFromShaped(shaped_type);
  return npcompNumpyNdArrayTypeGetFromShaped(shaped_type);
}
}

MlirType ndarrayToTensorType(MlirType ndarray_type) {
MlirType ndarrayToTensorType(MlirType ndarray_type) {
  if (!npcompTypeIsANdArray(ndarray_type)) {
  if (!npcompTypeIsANumpyNdArray(ndarray_type)) {
    throw py::raiseValueError("type is not an ndarray type");
    throw py::raiseValueError("type is not an ndarray type");
  }
  }
  return npcompNdArrayTypeToTensor(ndarray_type);
  return npcompNumpyNdArrayTypeToTensor(ndarray_type);
}
}

MlirType slotObjectType(MlirContext context, const std::string &className,
MlirType slotObjectType(MlirContext context, const std::string &className,
                        const std::vector<MlirType> &slotTypes) {
                        const std::vector<MlirType> &slotTypes) {
  MlirStringRef classNameSr{className.data(), className.size()};
  MlirStringRef classNameSr{className.data(), className.size()};
  return ::npcompSlotObjectTypeGet(context, classNameSr, slotTypes.size(),
  return ::npcompBasicPySlotObjectTypeGet(context, classNameSr,
                                   slotTypes.data());
                                          slotTypes.size(), slotTypes.data());
}
}

// TODO: Move this upstream.
// TODO: Move this upstream.
@ -12,8 +12,9 @@
#include "mlir-c/IR.h"
#include "mlir-c/IR.h"
#include "mlir-c/Registration.h"
#include "mlir-c/Registration.h"
#include "npcomp-c/BasicpyTypes.h"
#include "npcomp-c/NumpyTypes.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Registration.h"
#include "npcomp-c/Types.h"

#include <assert.h>
#include <assert.h>
#include <math.h>
#include <math.h>
@ -24,30 +25,31 @@
// Dumps an instance of all NPComp types.
// Dumps an instance of all NPComp types.
static int printStandardTypes(MlirContext ctx) {
static int printStandardTypes(MlirContext ctx) {
  // Bool type.
  // Bool type.
  MlirType boolType = npcompBoolTypeGet(ctx);
  MlirType boolType = npcompBasicpyBoolTypeGet(ctx);
  if (!npcompTypeIsABool(boolType))
  if (!npcompTypeIsABasicpyBool(boolType))
    return 1;
    return 1;
  mlirTypeDump(boolType);
  mlirTypeDump(boolType);
  fprintf(stderr, "\n");
  fprintf(stderr, "\n");

  // Bytes type.
  // Bytes type.
  MlirType bytesType = npcompBytesTypeGet(ctx);
  MlirType bytesType = npcompBasicpyBytesTypeGet(ctx);
  if (!npcompTypeIsABytes(bytesType))
  if (!npcompTypeIsABasicpyBytes(bytesType))
    return 1;
    return 1;
  mlirTypeDump(bytesType);
  mlirTypeDump(bytesType);
  fprintf(stderr, "\n");
  fprintf(stderr, "\n");

  // Any dtype.
  // Any dtype.
  MlirType anyDtype = npcompAnyDtypeTypeGet(ctx);
  MlirType anyDtype = npcompAnyDtypeTypeGet(ctx);
  if (!npcompTypeIsAAnyDtype(anyDtype))
  if (!npcompTypeIsANumpyAnyDtype(anyDtype))
    return 2;
    return 2;
  mlirTypeDump(anyDtype);
  mlirTypeDump(anyDtype);
  fprintf(stderr, "\n");
  fprintf(stderr, "\n");

  // Ranked NdArray.
  // Ranked NdArray.
  int64_t fourDim = 4;
  int64_t fourDim = 4;
  MlirType rankedNdArray = npcompNdArrayTypeGetRanked(1, &fourDim, boolType);
  MlirType rankedNdArray =
  if (!npcompTypeIsANdArray(rankedNdArray))
      npcompNumpyNdArrayTypeGetRanked(1, &fourDim, boolType);
  if (!npcompTypeIsANumpyNdArray(rankedNdArray))
    return 3;
    return 3;
  mlirTypeDump(rankedNdArray);
  mlirTypeDump(rankedNdArray);
  fprintf(stderr, "\n");
  fprintf(stderr, "\n");