mirror of https://github.com/llvm/torch-mlir
Add capture function arguments.
* Adds at::Tensor -> MlirValue tracking. * Adds conversions for tensor and scalar types to MLIR types. * Adds npcomp C APIs for constructing custom types. * Reworks pybind include so as to get Torch pybind helpers (needed to pass at::Tensor type from Python->C++).pull/64/head
parent
3d74337be0
commit
e5433e314f
|
@ -9,13 +9,14 @@ include_directories(
|
|||
link_directories("${TORCH_INSTALL_PREFIX}/lib")
|
||||
add_library(npcomp_torch_c10_dispatch_bindings
|
||||
acap_dispatch.cpp
|
||||
func_builder.cpp
|
||||
module_builder.cpp
|
||||
python_bindings.cpp
|
||||
)
|
||||
|
||||
get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
|
||||
target_link_libraries(npcomp_torch_c10_dispatch_bindings
|
||||
NPCOMPCAPIRegistration
|
||||
NPCOMPCAPIIR
|
||||
${TORCH_LIBRARIES}
|
||||
${PYTHON_LIBRARIES}
|
||||
${mlir_libs}
|
||||
|
|
|
@ -17,7 +17,9 @@
|
|||
#include <list>
|
||||
#include <memory>
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "../pybind.h"
|
||||
|
||||
#include "func_builder.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
|
@ -29,10 +31,8 @@ namespace torch_mlir {
|
|||
/// Main entry point for managing device capture.
|
||||
class AcapController : public std::enable_shared_from_this<AcapController> {
|
||||
public:
|
||||
AcapController(MlirOperation funcOp) : funcOp(funcOp) {
|
||||
// TODO: Remove once used (suppresses warning).
|
||||
(void)this->funcOp;
|
||||
}
|
||||
AcapController(std::unique_ptr<FuncBuilder> funcBuilder)
|
||||
: funcBuilder(std::move(funcBuilder)) {}
|
||||
|
||||
// Enter and exit the context manager.
|
||||
pybind11::object contextEnter();
|
||||
|
@ -63,7 +63,7 @@ private:
|
|||
// Gets the thread local stack of active acap controllers.
|
||||
static std::list<Activation> &getThreadLocalActiveStack();
|
||||
|
||||
MlirOperation funcOp;
|
||||
std::unique_ptr<FuncBuilder> funcBuilder;
|
||||
std::vector<std::string> captureLog;
|
||||
};
|
||||
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
//===- func_builder.cpp ---------------------------------------------------===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "func_builder.h"
|
||||
|
||||
#include "mlir-c/StandardAttributes.h"
|
||||
#include "mlir-c/StandardTypes.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
|
||||
using namespace torch_mlir;
|
||||
|
||||
MlirType TypeMapper::mapScalarType(c10::ScalarType scalarType) {
|
||||
using c10::ScalarType;
|
||||
switch (scalarType) {
|
||||
case ScalarType::Byte:
|
||||
return mlirIntegerTypeUnsignedGet(context, 8);
|
||||
case ScalarType::Char:
|
||||
return mlirIntegerTypeSignedGet(context, 8);
|
||||
case ScalarType::Short:
|
||||
return mlirIntegerTypeSignedGet(context, 16);
|
||||
case ScalarType::Int:
|
||||
return mlirIntegerTypeSignedGet(context, 32);
|
||||
case ScalarType::Long:
|
||||
return mlirIntegerTypeSignedGet(context, 64);
|
||||
case ScalarType::Bool:
|
||||
return npcompBoolTypeGet(context);
|
||||
case ScalarType::Double:
|
||||
return mlirF64TypeGet(context);
|
||||
case ScalarType::Float:
|
||||
return mlirF32TypeGet(context);
|
||||
case ScalarType::BFloat16:
|
||||
return mlirBF16TypeGet(context);
|
||||
case ScalarType::Half:
|
||||
return mlirF16TypeGet(context);
|
||||
default: {
|
||||
std::stringstream message;
|
||||
message << "unsupported PyTorch scalar type: " << c10::toString(scalarType);
|
||||
throw std::invalid_argument(message.str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
MlirType TypeMapper::forwardTensorToType(at::Tensor tensor) {
|
||||
if (!tensor.defined())
|
||||
throw std::invalid_argument("Tensor is not defined");
|
||||
|
||||
MlirType elementType = mapScalarType(tensor.scalar_type());
|
||||
// TODO: Decide when it is necessary to take strides into account. Right now,
|
||||
// just erase them and let the compiler decide.
|
||||
|
||||
auto sizes = tensor.sizes();
|
||||
return npcompNdArrayTypeGetRanked(sizes.size(), sizes.data(), elementType);
|
||||
}
|
||||
|
||||
static MlirOperation createEmptyReturnOp(MlirLocation location) {
|
||||
MlirOperationState state = mlirOperationStateGet("std.return", location);
|
||||
return mlirOperationCreate(&state);
|
||||
}
|
||||
|
||||
std::unique_ptr<FuncBuilder>
|
||||
FuncBuilder::createFunction(MlirContext context, MlirLocation location,
|
||||
llvm::StringRef name,
|
||||
llvm::SmallVectorImpl<MlirType> &inputTypes) {
|
||||
// TODO: Create a dedicated API upstream for creating/manipulating func ops.
|
||||
// (this is fragile and reveals details that are not guaranteed).
|
||||
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
|
||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||
"type", mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||
context, inputTypes.size(), inputTypes.data(),
|
||||
/*numResults=*/0, /*results=*/nullptr))));
|
||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||
"sym_name", mlirStringAttrGet(context, name.size(), name.data())));
|
||||
|
||||
MlirOperationState state = mlirOperationStateGet("func", location);
|
||||
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
|
||||
{
|
||||
// Don't access these once ownership transferred.
|
||||
MlirRegion newBodyRegion = mlirRegionCreate();
|
||||
MlirBlock newEntryBlock =
|
||||
mlirBlockCreate(inputTypes.size(), inputTypes.data());
|
||||
mlirRegionInsertOwnedBlockAfter(newBodyRegion, {nullptr}, newEntryBlock);
|
||||
mlirOperationStateAddOwnedRegions(&state, 1, &newBodyRegion);
|
||||
}
|
||||
|
||||
// Need to re-lookup the region/block because we relinquished ownership above.
|
||||
MlirOperation funcOp = mlirOperationCreate(&state);
|
||||
MlirRegion bodyRegion = mlirOperationGetRegion(funcOp, 0);
|
||||
MlirBlock entryBlock = mlirRegionGetFirstBlock(bodyRegion);
|
||||
|
||||
// Create an empty return op (will rework it later as return types become
|
||||
// known).
|
||||
MlirOperation returnOp = createEmptyReturnOp(location);
|
||||
mlirBlockInsertOwnedOperationBefore(entryBlock, {nullptr}, returnOp);
|
||||
|
||||
return std::unique_ptr<FuncBuilder>(new FuncBuilder(
|
||||
context, funcOp, BlockBuilder(entryBlock, returnOp, true)));
|
||||
}
|
||||
|
||||
MlirValue FuncBuilder::lookupTensor(at::Tensor tensor) {
|
||||
for (auto it = tensorValueMap.rbegin(), e = tensorValueMap.rend(); it != e;
|
||||
++it) {
|
||||
if (it->first.is_same(tensor))
|
||||
return it->second;
|
||||
}
|
||||
return {nullptr};
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
//===- func_builder.h -------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_FUNC_BUILDER_H
|
||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_FUNC_BUILDER_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Maps various runtime types to MlirType.
|
||||
class TypeMapper {
|
||||
public:
|
||||
TypeMapper(MlirContext context) : context(context) {}
|
||||
|
||||
/// Gets a corresponding MlirType for the Torch ScalarType.
|
||||
/// Throws std::invalid_argument on failure.
|
||||
MlirType mapScalarType(c10::ScalarType scalarType);
|
||||
|
||||
/// Gets a corresponding MlirType for the forward component of a tensor.
|
||||
/// Throws std::invalid_argument on failure.
|
||||
MlirType forwardTensorToType(at::Tensor tensor);
|
||||
|
||||
private:
|
||||
MlirContext context;
|
||||
};
|
||||
|
||||
/// Wraps an MlirBlock under construction, primarily tracking the terminator
|
||||
/// and supporting manipulation of it. The terminator may be null if it has
|
||||
/// not yet been constructed, although, for entry blocks, we always construct
|
||||
/// the function with an appropriate return terminator (which can be changed
|
||||
/// later).
|
||||
class BlockBuilder {
|
||||
public:
|
||||
BlockBuilder(MlirBlock block, MlirOperation terminator, bool isReturn)
|
||||
: block(block), terminator(terminator), isReturn(isReturn) {}
|
||||
|
||||
MlirBlock getBlock() { return block; }
|
||||
MlirOperation getTerminator() { return terminator; }
|
||||
bool getIsReturnTerminator() { return isReturn; }
|
||||
|
||||
private:
|
||||
MlirBlock block;
|
||||
MlirOperation terminator;
|
||||
bool isReturn;
|
||||
};
|
||||
|
||||
/// Wraps a 'func' MlirOperation and provides facilities for constructing
|
||||
/// IR from some stream of Torch operations.
|
||||
class FuncBuilder {
|
||||
public:
|
||||
/// Creates a new func op with the given characteristics. The created
|
||||
/// operation is not attached. The caller must either destroy it or add it
|
||||
/// to a parent.
|
||||
static std::unique_ptr<FuncBuilder>
|
||||
createFunction(MlirContext context, MlirLocation location,
|
||||
llvm::StringRef name,
|
||||
llvm::SmallVectorImpl<MlirType> &inputTypes);
|
||||
|
||||
MlirOperation getFuncOp() { return funcOp; }
|
||||
|
||||
/// Gets the function's entry block.
|
||||
MlirBlock getEntryBlock() { return entryBlock.getBlock(); }
|
||||
|
||||
/// Maps a live Tensor to an MlirValue.
|
||||
void mapTensor(at::Tensor tensor, MlirValue value) {
|
||||
tensorValueMap.push_back(std::make_pair(tensor, value));
|
||||
}
|
||||
|
||||
/// Looks up a current mapping of tensor to an MlirValue, returning a null
|
||||
/// value if not found.
|
||||
MlirValue lookupTensor(at::Tensor tensor);
|
||||
|
||||
private:
|
||||
FuncBuilder(MlirContext context, MlirOperation funcOp,
|
||||
BlockBuilder entryBlock)
|
||||
: context(context), funcOp(funcOp), entryBlock(std::move(entryBlock)) {
|
||||
(void)this->context;
|
||||
}
|
||||
|
||||
MlirContext context;
|
||||
|
||||
/// The func op under construction.
|
||||
MlirOperation funcOp;
|
||||
|
||||
/// Block builder for the entry block.
|
||||
BlockBuilder entryBlock;
|
||||
|
||||
/// Maps tensors to MlirValue. Unfortunately, this needs to be a linear scan
|
||||
/// because the impl pointer for the Tensor is not accessible. To make this
|
||||
/// slightly better, we add to the back and lookup in reverse under the idea
|
||||
/// that tensors may be mapped and accessed in proximity.
|
||||
llvm::SmallVector<std::pair<at::Tensor, MlirValue>, 16> tensorValueMap;
|
||||
};
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
|
@ -46,11 +46,14 @@ ModuleBuilder::ModuleBuilder()
|
|||
// semantics and interop). Until then, they are just scoped to this instance
|
||||
// and must not escape.
|
||||
: context(mlirContextCreate()), unknownLoc(mlirLocationUnknownGet(context)),
|
||||
module(mlirModuleCreateEmpty(unknownLoc)) {
|
||||
module(mlirModuleCreateEmpty(unknownLoc)), typeMapper(context) {
|
||||
// TODO: Rework this once dialect registration C-APIs are in place.
|
||||
// https://reviews.llvm.org/D88162
|
||||
mlirRegisterAllDialects(context);
|
||||
npcompRegisterAllDialects(context);
|
||||
|
||||
// Terminator will always be the first op of an empty module.
|
||||
terminator = mlirBlockGetFirstOperation(getBodyBlock());
|
||||
}
|
||||
|
||||
ModuleBuilder::~ModuleBuilder() {
|
||||
|
@ -67,47 +70,35 @@ py::str ModuleBuilder::getAsm() {
|
|||
}
|
||||
|
||||
std::shared_ptr<AcapController>
|
||||
ModuleBuilder::startCaptureFunction(std::string &name) {
|
||||
// TODO: Populate input/result types.
|
||||
ModuleBuilder::startCaptureFunction(std::string &name,
|
||||
std::vector<at::Tensor> args) {
|
||||
// TODO: Verify that arguments do not alias each other.
|
||||
llvm::SmallVector<MlirType, 4> inputTypes;
|
||||
llvm::SmallVector<MlirType, 4> resultTypes;
|
||||
MlirOperation funcOp = createFunction(name, inputTypes, resultTypes);
|
||||
return std::make_shared<AcapController>(funcOp);
|
||||
}
|
||||
|
||||
// TODO: Implement an mlir-c API for creating a function and avoid the danger
|
||||
// of getting the below wrong.
|
||||
MlirOperation
|
||||
ModuleBuilder::createFunction(std::string &name,
|
||||
llvm::SmallVectorImpl<MlirType> &inputTypes,
|
||||
llvm::SmallVectorImpl<MlirType> &resultTypes) {
|
||||
MlirOperation moduleOp = mlirModuleGetOperation(module);
|
||||
MlirBlock moduleBlock =
|
||||
mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
||||
|
||||
llvm::SmallVector<MlirNamedAttribute, 4> funcAttrs;
|
||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||
"type", mlirTypeAttrGet(mlirFunctionTypeGet(
|
||||
context, inputTypes.size(), inputTypes.data(),
|
||||
resultTypes.size(), resultTypes.data()))));
|
||||
funcAttrs.push_back(mlirNamedAttributeGet(
|
||||
"sym_name", mlirStringAttrGet(context, name.size(), name.data())));
|
||||
|
||||
// TODO: Extract current traceback and use it for location.
|
||||
MlirOperationState state = mlirOperationStateGet("func", unknownLoc);
|
||||
mlirOperationStateAddAttributes(&state, funcAttrs.size(), funcAttrs.data());
|
||||
{
|
||||
// Don't access these once ownership transferred.
|
||||
MlirRegion bodyRegion = mlirRegionCreate();
|
||||
MlirBlock entryBlock =
|
||||
mlirBlockCreate(inputTypes.size(), inputTypes.data());
|
||||
mlirRegionInsertOwnedBlockAfter(bodyRegion, {nullptr}, entryBlock);
|
||||
mlirOperationStateAddOwnedRegions(&state, 1, &bodyRegion);
|
||||
for (auto &arg : args) {
|
||||
inputTypes.push_back(typeMapper.forwardTensorToType(arg));
|
||||
}
|
||||
|
||||
MlirOperation funcOp = mlirOperationCreate(&state);
|
||||
mlirBlockInsertOwnedOperationAfter(moduleBlock, {nullptr}, funcOp);
|
||||
return funcOp;
|
||||
// TODO: Extract a traceback and use in place of unknownLoc.
|
||||
auto funcBuilder =
|
||||
FuncBuilder::createFunction(context, unknownLoc, name, inputTypes);
|
||||
mlirBlockInsertOwnedOperationBefore(getBodyBlock(), terminator,
|
||||
funcBuilder->getFuncOp());
|
||||
|
||||
// Map block arguments.
|
||||
MlirBlock entryBlock = funcBuilder->getEntryBlock();
|
||||
assert(mlirBlockGetNumArguments(entryBlock) ==
|
||||
static_cast<intptr_t>(args.size()) &&
|
||||
"entry block incorrect arg arity");
|
||||
for (auto it : llvm::enumerate(args)) {
|
||||
funcBuilder->mapTensor(it.value(),
|
||||
mlirBlockGetArgument(entryBlock, it.index()));
|
||||
}
|
||||
return std::make_shared<AcapController>(std::move(funcBuilder));
|
||||
}
|
||||
|
||||
MlirBlock ModuleBuilder::getBodyBlock() {
|
||||
MlirOperation moduleOp = mlirModuleGetOperation(module);
|
||||
return mlirRegionGetFirstBlock(mlirOperationGetRegion(moduleOp, 0));
|
||||
}
|
||||
|
||||
void ModuleBuilder::bind(py::module &m) {
|
||||
|
|
|
@ -8,13 +8,16 @@
|
|||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_C10_DISPATCH_MODULE_BUILDER_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
// TODO: Remove this dep once the getAsm() method is removed.
|
||||
#include "../pybind.h"
|
||||
|
||||
#include "acap_dispatch.h"
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
/// Main entry-point for constructing an MLIR module from some combination
|
||||
|
@ -32,17 +35,18 @@ public:
|
|||
|
||||
// Starts a device-capture based function.
|
||||
// TODO: Add inputs.
|
||||
std::shared_ptr<AcapController> startCaptureFunction(std::string &name);
|
||||
std::shared_ptr<AcapController>
|
||||
startCaptureFunction(std::string &name, std::vector<at::Tensor> args);
|
||||
|
||||
private:
|
||||
// Creates a new top-level function and returns its operation.
|
||||
MlirOperation createFunction(std::string &name,
|
||||
llvm::SmallVectorImpl<MlirType> &inputTypes,
|
||||
llvm::SmallVectorImpl<MlirType> &resultTypes);
|
||||
MlirBlock getBodyBlock();
|
||||
|
||||
MlirContext context;
|
||||
MlirLocation unknownLoc;
|
||||
MlirModule module;
|
||||
MlirOperation terminator;
|
||||
|
||||
TypeMapper typeMapper;
|
||||
};
|
||||
|
||||
} // namespace torch_mlir
|
||||
|
|
|
@ -5,9 +5,9 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "../pybind.h"
|
||||
|
||||
#include <ATen/core/dispatch/Dispatcher.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "../init_python_bindings.h"
|
||||
#include "acap_dispatch.h"
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
#ifndef INIT_PYTHON_BINDINGS_H
|
||||
#define INIT_PYTHON_BINDINGS_H
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "pybind.h"
|
||||
|
||||
namespace torch_mlir {
|
||||
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
//===- module_builder.h -----------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under a pytorch-style license
|
||||
// See frontends/pytorch/LICENSE for license information.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Includes Torch-specific pybind and associated helpers.
|
||||
// Depend on this for access to all Torch types (versus depending on pybind11
|
||||
// directly).
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_FRONTENDS_PYTORCH_CSRC_PYBIND_H
|
||||
#define NPCOMP_FRONTENDS_PYTORCH_CSRC_PYBIND_H
|
||||
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
|
||||
#endif // NPCOMP_FRONTENDS_PYTORCH_CSRC_PYBIND_H
|
|
@ -18,7 +18,7 @@
|
|||
// In this case t2_cpu contains the result of the computation, and t2_mlir
|
||||
// contains the mlir description of the computation.
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "../pybind.h"
|
||||
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/MemoryBuffer.h"
|
||||
|
|
|
@ -3,6 +3,8 @@
|
|||
# See frontends/pytorch/LICENSE for license information.
|
||||
# RUN: python %s | FileCheck %s
|
||||
|
||||
# TODO: Once stabilized, expand tests to include all argument dtypes.
|
||||
|
||||
import torch
|
||||
import _torch_mlir as m
|
||||
|
||||
|
@ -10,11 +12,12 @@ t0 = torch.randn((4,4))
|
|||
t1 = torch.randn((4,4))
|
||||
|
||||
mb = m.ModuleBuilder()
|
||||
with mb.capture_function("foobar") as c:
|
||||
with mb.capture_function("foobar", [t0, t1]) as c:
|
||||
result = t0 + t1
|
||||
|
||||
# CHECK: module {
|
||||
# CHECK: func @foobar() {
|
||||
# CHECK: func @foobar(%arg0: !numpy.ndarray<[4,4]:f32>, %arg1: !numpy.ndarray<[4,4]:f32>) {
|
||||
# CHECK: return
|
||||
# CHECK: }
|
||||
# CHECK: }
|
||||
print(mb)
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
/*===-- npcomp-c/Types.h - NPComp custom types --------------------*- C -*-===*\
|
||||
|* *|
|
||||
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|
||||
|* Exceptions. *|
|
||||
|* See https://llvm.org/LICENSE.txt for license information. *|
|
||||
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|
||||
|* *|
|
||||
\*===----------------------------------------------------------------------===*/
|
||||
|
||||
#ifndef NPCOMP_C_TYPES_H
|
||||
#define NPCOMP_C_TYPES_H
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the Python "bool" type. */
|
||||
int npcompTypeIsABool(MlirType t);
|
||||
|
||||
/** Gets the Python "bool" type. */
|
||||
MlirType npcompBoolTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is the special "any dtype" type that is used
|
||||
* to signal an NDArray or tensor of unknown type. */
|
||||
int npcompTypeIsAAnyDtype(MlirType t);
|
||||
|
||||
/** Gets the "any dtype" type. */
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context);
|
||||
|
||||
/*============================================================================*/
|
||||
/* NDArray type. */
|
||||
/*============================================================================*/
|
||||
|
||||
/** Checks whether the given type is an NdArray type. */
|
||||
int npcompTypeIsANdArray(MlirType t);
|
||||
|
||||
/** Gets a numpy.NdArray type that is ranked. Any dimensions that are -1 are
|
||||
* unknown. */
|
||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // NPCOMP_C_TYPES_H
|
|
@ -1,8 +1,10 @@
|
|||
add_mlir_library(NPCOMPCAPIRegistration
|
||||
add_mlir_library(NPCOMPCAPIIR
|
||||
Registration.cpp
|
||||
|
||||
Types.cpp
|
||||
EXCLUDE_FROM_LIBMLIR
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
NPCOMPInitAll
|
||||
NPCOMPBasicpyDialect
|
||||
NPCOMPNumpyDialect
|
||||
)
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
//===- Types.cpp - C Interface for NPComp types ---------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/Types.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||
|
||||
using namespace mlir::NPCOMP::Basicpy;
|
||||
using namespace mlir::NPCOMP::Numpy;
|
||||
|
||||
/*============================================================================*/
|
||||
/* Bool type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsABool(MlirType t) { return unwrap(t).isa<BoolType>(); }
|
||||
|
||||
MlirType npcompBoolTypeGet(MlirContext context) {
|
||||
return wrap(BoolType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* Any dtype type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsAAnyDtype(MlirType t) { return unwrap(t).isa<AnyDtypeType>(); }
|
||||
|
||||
MlirType npcompAnyDtypeTypeGet(MlirContext context) {
|
||||
return wrap(AnyDtypeType::get(unwrap(context)));
|
||||
}
|
||||
|
||||
/*============================================================================*/
|
||||
/* NDArray type. */
|
||||
/*============================================================================*/
|
||||
|
||||
int npcompTypeIsANdArray(MlirType t) { return unwrap(t).isa<NdArrayType>(); }
|
||||
|
||||
MlirType npcompNdArrayTypeGetRanked(intptr_t rank, const int64_t *shape,
|
||||
MlirType elementType) {
|
||||
llvm::ArrayRef<int64_t> shapeArray(shape, rank);
|
||||
return wrap(NdArrayType::get(unwrap(elementType), shapeArray));
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
set(LLVM_LINK_COMPONENTS
|
||||
Core
|
||||
Support
|
||||
)
|
||||
|
||||
add_llvm_executable(npcomp-capi-ir-test
|
||||
ir.c
|
||||
)
|
||||
llvm_update_compile_flags(npcomp-capi-ir-test)
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
|
||||
target_link_libraries(npcomp-capi-ir-test
|
||||
PRIVATE
|
||||
NPCOMPCAPIIR
|
||||
MLIRCAPIRegistration
|
||||
${dialect_libs})
|
|
@ -0,0 +1,70 @@
|
|||
/*===- ir.c - Simple test of C APIs ---------------------------------------===*\
|
||||
|* *|
|
||||
|* Part of the LLVM Project, under the Apache License v2.0 with LLVM *|
|
||||
|* Exceptions. *|
|
||||
|* See https://llvm.org/LICENSE.txt for license information. *|
|
||||
|* SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception *|
|
||||
|* *|
|
||||
\*===----------------------------------------------------------------------===*/
|
||||
|
||||
/* RUN: npcomp-capi-ir-test 2>&1 | FileCheck %s
|
||||
*/
|
||||
|
||||
#include "mlir-c/IR.h"
|
||||
#include "mlir-c/Registration.h"
|
||||
#include "npcomp-c/Registration.h"
|
||||
#include "npcomp-c/Types.h"
|
||||
|
||||
#include <assert.h>
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <string.h>
|
||||
|
||||
// Dumps an instance of all NPComp types.
|
||||
static int printStandardTypes(MlirContext ctx) {
|
||||
// Bool type.
|
||||
MlirType boolType = npcompBoolTypeGet(ctx);
|
||||
if (!npcompTypeIsABool(boolType))
|
||||
return 1;
|
||||
mlirTypeDump(boolType);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// Any dtype.
|
||||
MlirType anyDtype = npcompAnyDtypeTypeGet(ctx);
|
||||
if (!npcompTypeIsAAnyDtype(anyDtype))
|
||||
return 2;
|
||||
mlirTypeDump(anyDtype);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// Ranked NdArray.
|
||||
int64_t fourDim = 4;
|
||||
MlirType rankedNdArray = npcompNdArrayTypeGetRanked(1, &fourDim, boolType);
|
||||
if (!npcompTypeIsANdArray(rankedNdArray))
|
||||
return 3;
|
||||
mlirTypeDump(rankedNdArray);
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main() {
|
||||
MlirContext ctx = mlirContextCreate();
|
||||
mlirRegisterAllDialects(ctx);
|
||||
npcompRegisterAllDialects(ctx);
|
||||
|
||||
// clang-format off
|
||||
// CHECK-LABEL: @types
|
||||
// CHECK: !basicpy.BoolType
|
||||
// CHECK: !numpy.any_dtype
|
||||
// CHECK: !numpy.ndarray<[4]:!basicpy.BoolType>
|
||||
// CHECK: 0
|
||||
// clang-format on
|
||||
fprintf(stderr, "@types\n");
|
||||
int errcode = printStandardTypes(ctx);
|
||||
fprintf(stderr, "%d\n", errcode);
|
||||
|
||||
mlirContextDestroy(ctx);
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
config.suffixes.add('.c')
|
|
@ -1,3 +1,5 @@
|
|||
add_subdirectory(CAPI)
|
||||
|
||||
llvm_canonicalize_cmake_booleans(
|
||||
NPCOMP_ENABLE_IREE
|
||||
)
|
||||
|
@ -11,6 +13,7 @@ configure_lit_site_cfg(
|
|||
|
||||
set(NPCOMP_TEST_DEPENDS
|
||||
FileCheck count not
|
||||
npcomp-capi-ir-test
|
||||
npcomp-opt
|
||||
npcomp-run-mlir
|
||||
NPCOMPNativePyExt
|
||||
|
|
|
@ -41,9 +41,7 @@ llvm_config.use_default_substitutions()
|
|||
# excludes: A list of directories to exclude from the testsuite. The 'Inputs'
|
||||
# subdirectories contain auxiliary inputs for various tests in their parent
|
||||
# directories.
|
||||
config.excludes = [
|
||||
'lit.cfg.py', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'
|
||||
]
|
||||
config.excludes = ['lit.cfg.py', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt']
|
||||
|
||||
# test_source_root: The root path where tests are located.
|
||||
config.test_source_root = os.path.dirname(__file__)
|
||||
|
@ -52,26 +50,25 @@ config.test_source_root = os.path.dirname(__file__)
|
|||
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test')
|
||||
config.npcomp_tools_dir = os.path.join(config.npcomp_obj_root, 'tools')
|
||||
config.npcomp_runtime_shlib = os.path.join(
|
||||
config.npcomp_obj_root,
|
||||
'lib',
|
||||
'libNPCOMPCompilerRuntimeShlib' + config.llvm_shlib_ext
|
||||
)
|
||||
config.npcomp_obj_root, 'lib',
|
||||
'libNPCOMPCompilerRuntimeShlib' + config.llvm_shlib_ext)
|
||||
|
||||
# Tweak the PATH and PYTHONPATH to include the tools dir.
|
||||
llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True)
|
||||
llvm_config.with_environment('PYTHONPATH', [
|
||||
os.path.join(config.npcomp_obj_root, "python")
|
||||
],
|
||||
llvm_config.with_environment('PYTHONPATH',
|
||||
[os.path.join(config.npcomp_obj_root, "python")],
|
||||
append_path=True)
|
||||
|
||||
tool_dirs = [
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-opt'),
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'),
|
||||
config.llvm_tools_dir,
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-opt'),
|
||||
os.path.join(config.npcomp_tools_dir, 'npcomp-run-mlir'),
|
||||
os.path.join(config.test_exec_root, 'CAPI'),
|
||||
config.llvm_tools_dir,
|
||||
]
|
||||
tools = [
|
||||
'npcomp-opt',
|
||||
'npcomp-run-mlir',
|
||||
'npcomp-capi-ir-test',
|
||||
ToolSubst('%npcomp_runtime_shlib', config.npcomp_runtime_shlib),
|
||||
]
|
||||
|
||||
|
|
Loading…
Reference in New Issue