Remove old RefBackend

It is superseded by the new one.
pull/319/head
Sean Silva 2021-09-22 22:25:03 +00:00
parent f9c48d0b89
commit a25163fbfa
73 changed files with 1 addition and 4645 deletions

View File

@@ -22,7 +22,6 @@ endif()
# Options and settings
#-------------------------------------------------------------------------------
set(NPCOMP_MINIMUM_PYTHON_VERSION 3.6)
option(NPCOMP_ENABLE_REFJIT "Enables the reference JIT backend." ON)
set(NPCOMP_IREE_BUILDDIR "../iree-build" CACHE STRING "If building IREE, then setting this elects to build from a source directory (versus installed package)")
# Turn on -gsplit-dwarf if requested in debug builds.
@@ -201,11 +200,6 @@ set(NPCOMP_TABLEGEN_ARGS "")
# Optional feature selection
#-------------------------------------------------------------------------------
if(NPCOMP_ENABLE_REFJIT)
add_compile_definitions(NPCOMP_ENABLE_REFJIT)
message(STATUS "Reference JIT backend enabled")
endif()
#-------------------------------------------------------------------------------
# Subdirectories and aggregate testing targets.
#-------------------------------------------------------------------------------

View File

@@ -1,99 +0,0 @@
//===-- npcomp-c/RefJITBackend.h - C API for the reference JIT ----*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
// Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_C_REFJITBACKEND_H
#define NPCOMP_C_REFJITBACKEND_H
#include <stdbool.h>
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#ifdef __cplusplus
extern "C" {
#endif
// Define opaque API structs.
#define DEFINE_C_API_STRUCT(name, storage) \
struct name { \
storage *ptr; \
}; \
typedef struct name name
DEFINE_C_API_STRUCT(NpcompRefJitModule, void);
DEFINE_C_API_STRUCT(NpcompRefJitValueList, void);
#undef DEFINE_C_API_STRUCT
// Must be kept in sync with C++ side.
enum NpcompRefJitElementType {
NPCOMP_REFJIT_NONE = 0,
NPCOMP_REFJIT_F32 = 1,
};
/// Populates a PassManager with a pipeline that performs backend compilation.
/// The resulting module can be passed to npcompRefJitModuleCreate().
MLIR_CAPI_EXPORTED void
npcompRefJitBuildBackendCompilationPipeline(MlirPassManager passManager,
bool optimize);
/// Creates a RefJit module from an MlirModule (as compiled from the above
/// pipeline). On success, returns a non-null NpcompRefJitModule. On failure,
/// returns a null module and stores a malloc()-allocated error message into
/// *errorMessage, which the caller must free().
MLIR_CAPI_EXPORTED NpcompRefJitModule
npcompRefJitModuleCreate(MlirModule module, MlirStringRef *sharedLibs,
intptr_t sharedLibsSize, char **errorMessage);
/// Whether the module is null.
static inline bool npcompRefJitModuleIsNull(NpcompRefJitModule m) {
return !m.ptr;
}
/// Destroys a refjit module.
MLIR_CAPI_EXPORTED void npcompRefJitModuleDestroy(NpcompRefJitModule module);
/// Invokes a function on a RefJit module. On success, returns true, clears
/// the passed input/output list, and adds all outputs to it. On failure,
/// returns false and populates *errorMessage with a malloc()-allocated error
/// message, which the caller must free().
MLIR_CAPI_EXPORTED bool
npcompRefJitModuleInvoke(NpcompRefJitModule m, MlirStringRef functionName,
NpcompRefJitValueList inputOutputs,
char **errorMessage);
/// Creates an empty value list.
MLIR_CAPI_EXPORTED NpcompRefJitValueList npcompRefJitValueListCreate();
/// Destroys a value list.
MLIR_CAPI_EXPORTED void
npcompRefJitValueListDestroy(NpcompRefJitValueList list);
/// Returns the size of the value list.
MLIR_CAPI_EXPORTED intptr_t
npcompRefJitValueListSize(NpcompRefJitValueList list);
/// Adds values to the list.
MLIR_CAPI_EXPORTED void npcompRefJitValueAddTensorCopy(
NpcompRefJitValueList list, NpcompRefJitElementType elementType,
const int32_t *extents, intptr_t extentsSize, const void *data);
// Reads a Tensor from a list.
MLIR_CAPI_EXPORTED bool npcompRefJitValueIsaTensor(NpcompRefJitValueList list,
intptr_t i);
MLIR_CAPI_EXPORTED void *
npcompRefJitValueGetTensor(NpcompRefJitValueList list, intptr_t i,
NpcompRefJitElementType *elementType, intptr_t *rank,
const int32_t **extents);
#ifdef __cplusplus
}
#endif
#endif // NPCOMP_C_REFJITBACKEND_H
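
For reference while reading the deletion, here is a minimal sketch of how this C API was typically driven. It is a hypothetical driver, not part of the commit: the MlirContext/MlirModule arguments, the "forward" entry-point name, and the MLIR C API calls (mlirPassManagerCreate, mlirPassManagerRun, mlirStringRefCreateFromCString) are assumptions about the caller's setup of that era.

// Hypothetical driver for the RefJit C API (sketch; assumptions noted above).
#include "mlir-c/IR.h"
#include "mlir-c/Pass.h"
#include "mlir-c/Support.h"
#include "npcomp-c/RefJITBackend.h"
#include <cstdint>
#include <cstdio>
#include <cstdlib>

bool runForward(MlirContext ctx, MlirModule module) {
  char *errorMessage = nullptr;
  // Lower the module through the backend compilation pipeline.
  MlirPassManager pm = mlirPassManagerCreate(ctx);
  npcompRefJitBuildBackendCompilationPipeline(pm, /*optimize=*/true);
  MlirLogicalResult result = mlirPassManagerRun(pm, module);
  mlirPassManagerDestroy(pm);
  if (mlirLogicalResultIsFailure(result))
    return false;
  // JIT-compile the lowered module.
  NpcompRefJitModule jit = npcompRefJitModuleCreate(
      module, /*sharedLibs=*/nullptr, /*sharedLibsSize=*/0, &errorMessage);
  if (npcompRefJitModuleIsNull(jit)) {
    std::fprintf(stderr, "%s\n", errorMessage);
    std::free(errorMessage);
    return false;
  }
  // Pass one 2x3 f32 tensor in; on success the list holds the outputs.
  const int32_t extents[2] = {2, 3};
  const float data[6] = {0};
  NpcompRefJitValueList values = npcompRefJitValueListCreate();
  npcompRefJitValueAddTensorCopy(values, NPCOMP_REFJIT_F32, extents,
                                 /*extentsSize=*/2, data);
  bool ok = npcompRefJitModuleInvoke(
      jit, mlirStringRefCreateFromCString("forward"), values, &errorMessage);
  if (ok && npcompRefJitValueIsaTensor(values, 0)) {
    NpcompRefJitElementType elementType;
    intptr_t rank;
    const int32_t *outExtents;
    void *outData = npcompRefJitValueGetTensor(values, /*i=*/0, &elementType,
                                               &rank, &outExtents);
    (void)outData; // e.g. read the f32 results described by rank/outExtents
  } else if (!ok) {
    std::fprintf(stderr, "%s\n", errorMessage);
    std::free(errorMessage);
  }
  npcompRefJitValueListDestroy(values);
  npcompRefJitModuleDestroy(jit);
  return ok;
}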

View File

@@ -1,4 +1,3 @@
add_subdirectory(Backend)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(RefBackend)

View File

@@ -1,3 +1 @@
add_subdirectory(Refback)
add_subdirectory(Refbackrt)
add_subdirectory(TorchConversion)

View File

@@ -1 +0,0 @@
add_subdirectory(IR)

View File

@@ -1 +0,0 @@
add_mlir_dialect(RefbackOps refback)

View File

@@ -1,23 +0,0 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACK_BASE
#define REFBACK_BASE
include "mlir/IR/OpBase.td"
def Refback_Dialect : Dialect {
let name = "refback";
let cppNamespace = "::mlir::NPCOMP::refback";
let description = [{
Ops used by the reference backend as part of its lowering.
}];
}
#endif // REFBACK_BASE

View File

@@ -1,16 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_REFBACK_IR_REFBACKDIALECT_H
#define NPCOMP_DIALECT_REFBACK_IR_REFBACKDIALECT_H
#include "mlir/IR/Dialect.h"
#include "npcomp/Dialect/Refback/IR/RefbackOpsDialect.h.inc"
#endif // NPCOMP_DIALECT_REFBACK_IR_REFBACKDIALECT_H

View File

@@ -1,22 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H
#define NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/Refback/IR/RefbackOps.h.inc"
#endif // NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H

View File

@@ -1,40 +0,0 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACK_OPS
#define REFBACK_OPS
include "npcomp/Dialect/Refback/IR/RefbackBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
class Refback_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Refback_Dialect, mnemonic, traits> {
}
//===----------------------------------------------------------------------===//
// Ops related to bufferization.
//===----------------------------------------------------------------------===//
def Refback_AllocMemRefOp : Refback_Op<"alloc_memref", []> {
let summary = "Allocates a memref of the given shape.";
let description = [{
Allocates a memref of the given shape.
This op is a convenience for creating a bunch of
tensor.extract ops + std.alloc.
}];
let arguments = (ins Shape_ExtentTensorType:$shape);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}
#endif // REFBACK_OPS

View File

@@ -1 +0,0 @@
add_subdirectory(IR)

View File

@@ -1 +0,0 @@
add_mlir_dialect(RefbackrtOps refbackrt)

View File

@@ -1,26 +0,0 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACKRT_BASE
#define REFBACKRT_BASE
include "mlir/IR/OpBase.td"
def Refbackrt_Dialect : Dialect {
let name = "refbackrt";
let cppNamespace = "::mlir::NPCOMP::refbackrt";
let description = [{
The `refbackrt` dialect is the IR manifestation for interaction with the
reference backend runtime. It primarily serves as a layer that encapsulates the
data structures and functions available in the runtime, and facilitates
conversion to those conventions, such as by providing utilities for being
lowered to the llvm dialect.
}];
}
#endif // REFBACKRT_BASE

View File

@@ -1,31 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
#include "mlir/IR/Dialect.h"
namespace mlir {
namespace NPCOMP {
namespace refbackrt {
class TensorType : public Type::TypeBase<TensorType, Type, TypeStorage> {
public:
using Base::Base;
static TensorType get(MLIRContext *context) { return Base::get(context); }
};
} // namespace refbackrt
} // namespace NPCOMP
} // namespace mlir
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOpsDialect.h.inc"
#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H

View File

@@ -1,20 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#define GET_OP_CLASSES
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h.inc"
#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H

View File

@@ -1,119 +0,0 @@
//===-------------------------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACKRT_OPS
#define REFBACKRT_OPS
include "npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td"
include "mlir/IR/SymbolInterfaces.td"
class Refbackrt_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Refbackrt_Dialect, mnemonic, traits> {
}
def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
let summary = "Aborts if the predicate is true";
let description = [{
Aborts if the predicate is true.
}];
let arguments = (ins I1:$pred, StrAttr:$msg);
let results = (outs);
let assemblyFormat = "$pred `,` $msg attr-dict";
}
def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
]> {
let summary = "Global metadata for the module";
let description = [{
This op contains a region containing refbackrt.func_metadata ops,
which give information about the functions in the module. This allows
the module to be introspected when it is loaded, such as looking up
functions.
Future uses include checking how many results functions should have, or
what their argument types are expected to be, to provide clean and safe
errors when invocations fail.
TODO: Verify that there should be no more than one of these ops in a
module.
This op is designed to hold a region, which makes it easy to convert to
a single LLVM global with a single conversion pattern.
}];
let arguments = (ins);
let results = (outs);
let regions = (region SizedRegion<1>:$metadatas);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def Refbackrt_ModuleMetadataTerminatorOp
: Refbackrt_Op<"module_metadata_terminator",
[Terminator, HasParent<"ModuleMetadataOp">]> {
let summary = "Implicit terminator for ModuleMetadataOp's region";
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
}
def Refbackrt_FuncMetadataOp
: Refbackrt_Op<"func_metadata", [HasParent<"ModuleMetadataOp">]> {
let summary = "Runtime metadata for a single func";
let description = [{
Runtime metadata for a single func.
Contains type / shape information for arguments as described below:
* ArgType(s):
Integer value from `CompilerDataStructures.h` for each argument
indicating what type it is (e.g. Float, Int, Tensor, Dict, etc.)
* ElementType(s):
Certain input ArgType's also have an element type (e.g. Tensor<float>,
List<int>, etc.)
TODO(brycearden): Support nested types (e.g. List<Tensor<float>>)
* Rank(s):
Integer value indicating the rank for each argument.
* Shape(s):
Flattened hyper-rectangular representation of the shapes for each argument.
Since each shape's size varies based on the Rank, we pad out the shapes
to size kMaxRank to make ABI lowering easier. See `LowerToRefbackrtABI.cpp`
for details.
Shapes Example:
constexpr int kMaxRank = 6;
// func @f(%arg0: f32, %arg1: tensor<5xf32>) would result in...
inputShapes = dense<...> : tensor<12xi32>
// 2 shapes with 6 elements each, where only the first `rank` values in
// each shape are valid, so that the LowerToLLVM pass can update the
// struct(s) by just grabbing a pointer at
// %shape_ptr = %base + (kMaxRank * argIndex)
}];
let arguments = (ins
FlatSymbolRefAttr:$funcName,
I32Attr:$numInputs,
I32Attr:$numOutputs,
OptionalAttr<I32ElementsAttr>:$inputArgTypes,
OptionalAttr<I32ElementsAttr>:$inputElementTypes,
OptionalAttr<I32ElementsAttr>:$inputRanks,
OptionalAttr<I32ElementsAttr>:$inputShapes,
// I32ElementsAttr:$inputIsStatic,
OptionalAttr<I32ElementsAttr>:$outputArgTypes,
OptionalAttr<I32ElementsAttr>:$outputElementTypes,
OptionalAttr<I32ElementsAttr>:$outputRanks,
OptionalAttr<I32ElementsAttr>:$outputShapes
// I32ElementsAttr:$outputIsStatic
);
let results = (outs);
let assemblyFormat = "attr-dict";
let verifier = [{ return ::verify(*this); }];
}
#endif // REFBACKRT_OPS

View File

@@ -1,5 +0,0 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(NPCOMPRefBackendPassIncGen)
add_mlir_doc(Passes RefBackendPasses ./ -gen-pass-doc)

View File

@@ -1,54 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_JITRUNTIME_JITMODULE_H
#define NPCOMP_JITRUNTIME_JITMODULE_H
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/BuiltinOps.h"
#include "npcomp/RefBackend/Runtime/UserAPI.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include <memory>
namespace mlir {
class PassManager;
} // namespace mlir
namespace refback {
// Wrapper around refbackrt data structures and a JITted module, facilitating
// interaction.
class JITModule {
public:
/// Populates a PassManager with a pipeline that performs backend compilation.
/// The resulting module can be passed to fromCompiledModule().
static void buildBackendCompilationPipeline(mlir::PassManager &pm,
bool optimize = false);
/// Constructs a JITModule from a compiled Module.
/// The module should be the result of having run the backend compilation
/// pipeline successfully.
static llvm::Expected<std::unique_ptr<JITModule>>
fromCompiledModule(mlir::ModuleOp module,
llvm::ArrayRef<llvm::StringRef> sharedLibs);
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::RtValue> inputs);
private:
JITModule();
std::unique_ptr<mlir::ExecutionEngine> engine;
refbackrt::ModuleDescriptor *descriptor;
};
} // namespace refback
#endif // NPCOMP_JITRUNTIME_JITMODULE_H
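
A compile-and-invoke sketch against this interface, for orientation. It is hypothetical: the incoming ModuleOp, the "forward" entry point, and the input tensor are assumed, and error handling is abbreviated.

// Hypothetical use of JITModule (sketch; assumptions noted above).
#include "mlir/Pass/PassManager.h"
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>
#include <memory>

void compileAndRun(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  refback::JITModule::buildBackendCompilationPipeline(pm, /*optimize=*/true);
  if (mlir::failed(pm.run(module)))
    return;
  auto jitOrErr =
      refback::JITModule::fromCompiledModule(module, /*sharedLibs=*/{});
  if (!jitOrErr) {
    llvm::errs() << llvm::toString(jitOrErr.takeError()) << "\n";
    return;
  }
  std::unique_ptr<refback::JITModule> jit = std::move(*jitOrErr);
  // Inputs are refbackrt::RtValue; tensors are created via the runtime API.
  std::int32_t extents[2] = {2, 3};
  float data[6] = {0};
  refbackrt::RtValue input(refbackrt::Tensor::create(
      refbackrt::ArrayRef<std::int32_t>(extents, 2),
      refbackrt::ElementType::F32, data));
  auto outputsOrErr = jit->invoke("forward", {input});
  if (!outputsOrErr)
    llvm::errs() << llvm::toString(outputsOrErr.takeError()) << "\n";
}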

View File

@@ -1,9 +0,0 @@
Utilities for compiling and running on the reference backend with a JIT.
The runtime itself lives in {include,lib}/RefBackend/Runtime, but since it
is totally firewalled from the compiler codebase, it presents a fairly
bare-bones interface (e.g. it doesn't use libSupport, can't use LLVM's JIT
interfaces, etc.).
The interface provided in this directory uses standard LLVM conventions and
freely relies on libSupport, JIT utilities, etc.

View File

@@ -1,40 +0,0 @@
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_REFBACKEND_PASSES
#define NPCOMP_REFBACKEND_PASSES
include "mlir/Pass/PassBase.td"
def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
let summary = "Lower constructs requiring runtime support to `refbackrt`";
let description = [{
We have a specialized dialect `refbackrt` which models our runtime's data
structures; function signatures (and presumably, eventually, other ABI
boundaries like external calls if we ever support them) are converted to
its conventions.
The constructs requiring runtime support are:
- function signatures / module metadata
- error handling
}];
let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()";
}
def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
let summary = "Lower AllocMemRefOp's";
let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()";
let dependentDialects = ["tensor::TensorDialect", "memref::MemRefDialect"];
}
def LowerToLLVM : Pass<"refback-lower-to-llvm", "ModuleOp"> {
let summary = "Lower everything to LLVM";
let constructor = "mlir::NPCOMP::createLowerToLLVMPass()";
}
#endif // NPCOMP_REFBACKEND_PASSES

View File

@@ -1,51 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_REFBACKEND_REFBACKEND_H
#define NPCOMP_REFBACKEND_REFBACKEND_H
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
namespace mlir {
namespace NPCOMP {
/// Registers all RefBackend passes.
void registerRefBackendPasses();
// Look in createRefBackendLoweringPipeline for more information about how these
// passes fit together.
//
// Pass summaries are in Passes.td.
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
std::unique_ptr<OperationPass<ModuleOp>> createLowerToRefbackrtABIPass();
std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
std::unique_ptr<Pass> createRestrictedCanonicalizerPass();
struct RefBackendLoweringPipelineOptions
: public PassPipelineOptions<RefBackendLoweringPipelineOptions> {
// If this option is true, then perform optimizations.
// If this option is false, only do the bare minimum for correctness.
Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
llvm::cl::init(false)};
};
// The main pipeline that encapsulates the full RefBackend lowering.
void createRefBackendLoweringPipeline(
OpPassManager &pm, const RefBackendLoweringPipelineOptions &options);
} // namespace NPCOMP
} // namespace mlir
#endif // NPCOMP_REFBACKEND_REFBACKEND_H
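
For orientation, wiring this header into a tool looked roughly like the following sketch (the MLIRContext is assumed; registerRefBackendPasses() additionally enables the textual pass names alongside programmatic construction).

// Sketch: programmatic construction of the RefBackend lowering pipeline.
#include "mlir/Pass/PassManager.h"
#include "npcomp/RefBackend/RefBackend.h"

void buildRefBackendPipeline(mlir::MLIRContext *context) {
  mlir::NPCOMP::registerRefBackendPasses();
  mlir::PassManager pm(context);
  mlir::NPCOMP::RefBackendLoweringPipelineOptions options;
  options.optimize = true; // full optimizations, not bare-minimum lowering
  mlir::NPCOMP::createRefBackendLoweringPipeline(pm, options);
  // pm is now ready to run over a ModuleOp.
}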

View File

@@ -1,14 +0,0 @@
Refbackrt (namespace `refbackrt`) is the runtime support library for the
RefBackend. It is best practice to keep compiler and runtime code
totally firewalled.
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
As such, this directory should have NO DEPENDENCIES ON COMPILER CODE (no
LLVM libSupport, etc.).
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
This will cause some duplication, but history has shown that this
firewalling pays big dividends. In particular, compiler code very
frequently has binary sizes that are simply unacceptable in runtime
scenarios, such as MByte-sized dependencies like LLVM libSupport.
Runtime code should fit in kBytes.

View File

@@ -1,101 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Trimmed down support classes intended to provide a familiar LLVM-like API,
// but without actually pulling in the LLVM ones.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_RUNTIME_SUPPORT_H
#define NPCOMP_RUNTIME_SUPPORT_H
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
namespace refbackrt {
class StringRef {
public:
StringRef(const char *ptr, std::size_t length) : ptr(ptr), length(length) {}
// Construct from NUL-terminated C string.
StringRef(const char *ptr) : ptr(ptr), length(std::strlen(ptr)) {}
bool equals(StringRef other) {
if (length != other.length)
return false;
return std::memcmp(ptr, other.ptr, length) == 0;
}
const char *str() { return ptr; }
private:
const char *ptr;
std::size_t length;
};
inline bool operator==(StringRef lhs, StringRef rhs) { return lhs.equals(rhs); }
inline bool operator!=(StringRef lhs, StringRef rhs) {
return !operator==(lhs, rhs);
}
template <typename T> class ArrayRef {
public:
ArrayRef(const T *ptr, std::size_t length) : ptr(ptr), length(length) {}
const T &operator[](std::size_t i) const {
assert(i < length);
return ptr[i];
}
const T *data() const { return ptr; }
std::size_t size() const { return length; }
private:
const T *ptr;
std::size_t length;
};
template <typename T> class MutableArrayRef {
public:
MutableArrayRef(T *ptr, std::size_t length) : ptr(ptr), length(length) {}
T &operator[](std::size_t i) {
assert(i < length);
return ptr[i];
}
T *data() const { return ptr; }
std::size_t size() const { return length; }
private:
T *ptr;
std::size_t length;
};
// Literally copied from MLIR.
struct LogicalResult {
enum ResultEnum { Success, Failure } value;
LogicalResult(ResultEnum v) : value(v) {}
};
inline LogicalResult success(bool isSuccess = true) {
return LogicalResult{isSuccess ? LogicalResult::Success
: LogicalResult::Failure};
}
inline LogicalResult failure(bool isFailure = true) {
return LogicalResult{isFailure ? LogicalResult::Failure
: LogicalResult::Success};
}
inline bool succeeded(LogicalResult result) {
return result.value == LogicalResult::Success;
}
inline bool failed(LogicalResult result) {
return result.value == LogicalResult::Failure;
}
} // namespace refbackrt
#endif // NPCOMP_RUNTIME_SUPPORT_H
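
A tiny, hypothetical illustration of how these lookalikes are used; the values are made up, and the point is only that call sites read the same as with the LLVM originals.

// Sketch exercising the runtime support types (hypothetical values).
#include "npcomp/RefBackend/Runtime/Support.h"
#include <cassert>
#include <cstdint>

void supportDemo() {
  refbackrt::StringRef name("forward"); // from a NUL-terminated C string
  assert(name == refbackrt::StringRef("forward", 7));
  std::int32_t storage[3] = {2, 3, 4};
  refbackrt::ArrayRef<std::int32_t> extents(storage, 3);
  assert(extents.size() == 3 && extents[0] == 2);
  refbackrt::LogicalResult r = refbackrt::success(extents.size() == 3);
  assert(refbackrt::succeeded(r) && !refbackrt::failed(r));
}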

View File

@@ -1,421 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the public-facing interface for interacting with the npcomp
// runtime.
//
// This functionality is totally firewalled from the compiler codebase, so
// even if things superficially look similar, remember that there are no
// LLVM utilities here, memory allocation should be kept to a minimum, etc.
//
// npcomp/RefBackend/Runtime/Support.h provides some minimal LLVM-like support
// code to keep the API familiar.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_RUNTIME_USERAPI_H
#define NPCOMP_RUNTIME_USERAPI_H
#include "npcomp/RefBackend/Runtime/Support.h"
#include <array>
#include <atomic>
#include <cstdlib>
namespace refbackrt {
struct RtValue;
// Base class for any RefCounted object type
class RefTarget {
protected:
template <typename T> friend class Ref;
mutable std::atomic<size_t> refCount;
constexpr RefTarget() noexcept : refCount(0) {}
};
// Reference-counted handle to a type with a `refCount` member.
template <typename T> class Ref {
public:
Ref() { ptr = nullptr; }
// Creates a Ref and increments the refcount by 1.
// rawPtr must be allocated with std::malloc.
Ref(T *rawPtr) {
assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
ptr = rawPtr;
incref(ptr);
}
Ref(const Ref &other) {
ptr = other.ptr;
incref(ptr);
}
Ref(Ref &&other) { ptr = other.takePtr(); }
Ref &operator=(const Ref &other) {
if (&other == this)
return *this;
decref(ptr);
ptr = other.ptr;
incref(ptr);
return *this;
}
Ref &operator=(Ref &&other) {
if (&other == this)
return *this;
decref(ptr);
ptr = other.takePtr();
return *this;
}
~Ref() { decref(ptr); }
T &operator*() const { return *ptr; }
T *operator->() const { return ptr; }
T *get() const { return ptr; }
T *takePtr() {
auto *ret = ptr;
ptr = nullptr;
return ret;
}
int debugGetRefCount() { return ptr->refCount; }
private:
friend struct RtValue;
static void incref(T *ptr) {
if (!ptr)
return;
ptr->refCount += 1;
}
friend struct RtValue;
static void decref(T *ptr) {
if (!ptr)
return;
if (ptr->refCount.fetch_sub(1) == 1) {
ptr->~T();
std::free(static_cast<void *>(ptr));
}
}
T *ptr;
};
// The available data types.
enum class ElementType : std::int32_t {
NONE,
F32,
};
std::int32_t getElementTypeByteSize(ElementType type);
StringRef getElementTypeAsStringRef(ElementType type);
// Representation of a tensor.
class Tensor : public RefTarget {
public:
// Due to tail-allocated objects, this struct should never be directly
// constructed.
Tensor() = delete;
// Create a Tensor with the given extents and element type, with a buffer
// holding a copy of `data`.
static Ref<Tensor> create(ArrayRef<std::int32_t> extents,
ElementType elementType, void *data);
// Same as `create`, but returns a raw pointer.
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
ElementType elementType, void *data);
static Ref<Tensor> create(ArrayRef<std::int64_t> extents,
ElementType elementType, void *data);
// Same as `create`, but returns a raw pointer.
static Tensor *createRaw(ArrayRef<std::int64_t> extents,
ElementType elementType, void *data);
ElementType getElementType() const { return elementType; }
std::int32_t getRank() const { return rank; }
void *getData() const { return data; }
template <typename T> T *getData() const { return static_cast<T *>(data); }
std::int32_t getExtent(int dimension) const {
return getExtents()[dimension];
}
ArrayRef<std::int32_t> getExtents() const {
auto extents = const_cast<Tensor *>(this)->getMutableExtents();
return ArrayRef<std::int32_t>(extents.data(), extents.size());
}
// Returns the number of bytes occupied by the data representing this tensor.
// The total allocated amount might be higher to allow e.g. for alignment
// nudging.
std::int32_t getDataByteSize() const;
~Tensor() { std::free(allocatedPtr); }
private:
MutableArrayRef<std::int32_t> getMutableExtents() {
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
return MutableArrayRef<std::int32_t>(tail, rank);
}
ElementType elementType;
// The number of dimensions of this Tensor.
// There are `rank` tail-allocated std::int32_t values representing the
// tensor extents.
std::int32_t rank;
// The buffer base.
void *data;
// The raw pointer returned by the allocator (currently assumed to be
// malloc), suitable for freeing the buffer.
void *allocatedPtr;
// Sizes are tail-allocated.
};
// RtValue is a generic tagged union used to hold all value types
// The tag determines the type, and the payload represents the stored
// contents of an object. If an object is not trivially destructible,
// then it must be refcounted and must have a refCount.
#define NPCOMP_FORALL_PRIM_TAGS(_) \
_(None) \
_(Bool) \
_(Int) \
_(Float) \
_(Double)
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
#define NPCOMP_FORALL_TAGS(_) \
NPCOMP_FORALL_PRIM_TAGS(_) \
NPCOMP_FORALL_REF_TAGS(_)
struct RtValue final {
RtValue() : payload{0}, tag(Tag::None) {}
// Bool
RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; }
bool isBool() const { return Tag::Bool == tag; }
bool toBool() const {
assert(isBool());
return payload.asBool;
}
// Int
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
bool isInt() const { return Tag::Int == tag; }
int64_t toInt() const {
assert(isInt());
return payload.asInt;
}
// Float
RtValue(float f) : tag(Tag::Float) { payload.asFloat = f; }
bool isFloat() const { return Tag::Float == tag; }
float toFloat() const {
assert(isFloat());
return payload.asFloat;
}
// Double
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
bool isDouble() const { return Tag::Double == tag; }
double toDouble() const {
assert(isDouble());
return payload.asDouble;
}
// Tensor
RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) {
payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr());
}
bool isTensor() const { return Tag::Tensor == tag; }
Ref<Tensor> toTensor() const {
assert(isTensor());
return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr));
}
// Ref
bool isRef() const {
#define DEFINE_IS_REF(x) \
if (is##x()) { \
return true; \
}
NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF)
#undef DEFINE_IS_REF
return false;
}
// Scalar
bool isScalar() const {
return isBool() || isInt() || isFloat() || isDouble();
}
// RtValue (downcast)
const RtValue &toRtValue() const { return *this; }
RtValue &toRtValue() { return *this; }
// Stringify tag for debugging.
StringRef tagKind() const {
switch (tag) {
#define DEFINE_CASE(x) \
case Tag::x: \
return #x;
NPCOMP_FORALL_TAGS(DEFINE_CASE)
#undef DEFINE_CASE
}
// TODO(brycearden): Print tag here
return "InvalidTag!";
}
RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) {
if (isRef()) {
#define DEFINE_INCREF(x) \
if (is##x()) { \
Ref<x>::incref(static_cast<x *>(payload.asVoidPtr)); \
return; \
}
NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF)
#undef DEFINE_INCREF
assert(false && "Unsupported RtValue type");
}
}
RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); }
RtValue &operator=(RtValue &&rhs) & noexcept {
RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None
return *this;
}
RtValue &operator=(RtValue const &rhs) & {
RtValue(rhs).swap(*this);
return *this;
}
~RtValue() {
if (isRef()) {
#define DEFINE_DECREF(x) \
if (is##x()) { \
Ref<x>::decref(static_cast<x *>(payload.asVoidPtr)); \
return; \
}
NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF)
#undef DEFINE_DECREF
assert(false && "Unsupported RtValue type");
}
}
private:
void swap(RtValue &rhs) {
std::swap(payload, rhs.payload);
std::swap(tag, rhs.tag);
}
// NOTE: Runtime tags are intentionally private.
// Please use the helper functions above to query information about the type
// of a RtValue.
enum class Tag : std::uint32_t {
#define DEFINE_TAG(x) x,
NPCOMP_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
};
union Payload {
bool asBool;
int64_t asInt;
float asFloat;
double asDouble;
void *asVoidPtr;
};
RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {}
Payload payload;
Tag tag;
};
//===----------------------------------------------------------------------===//
// Module loading.
// This is the main entry point that users interact with.
//===----------------------------------------------------------------------===//
enum class ArgType : std::uint32_t {
kNone = 0,
kTensor,
kF32,
kF64,
};
StringRef getArgTypeAsStringRef(ArgType type);
// Maximum rank supported across the ABI boundary
constexpr static int kMaxRank = 6;
struct InputArgInfo {
// What type of argument this is
ArgType argType;
// Certain arg types also have an element type
ElementType elementType;
std::int32_t rank;
std::array<std::int32_t, kMaxRank> extents;
};
struct OutputArgInfo {
// What type of argument this is
ArgType argType;
// Certain arg types also have an element type
ElementType elementType;
std::int32_t rank;
std::array<std::int32_t, kMaxRank> extents;
// TODO(brycearden): Add checks for whether output buffers alias to input
// buffers and populate field(s) here indicating that case
};
// Maximum input or output arity.
constexpr static int kMaxArity = 20;
// Metadata for a particular function.
struct FunctionMetadata {
std::int32_t numInputs;
std::int32_t numOutputs;
std::array<InputArgInfo, kMaxArity> inputArgInfos;
std::array<OutputArgInfo, kMaxArity> outputArgInfos;
};
// Opaque forward declaration of module descriptor type. This is the type
// created by the compiler in the module binary.
struct ModuleDescriptor;
// Verifies that the arg types of the input RtValues provided by the user
// match the types we expect from the descriptors emitted by the compiler.
//
// Returns failure if the input type(s) are not valid
LogicalResult checkRtValueArgTypes(const RtValue &value,
const InputArgInfo &info);
// Verifies that the shapes of the input RtValues provided by the user
// match the shapes we expect from the descriptors emitted by the compiler.
//
// Returns failure if the input shape(s) are not valid
LogicalResult checkRtValueShapes(const RtValue &value,
const InputArgInfo &info);
// Creates an RtValue of the right type from the output metadata
// provided by the compiled module
RtValue createRtValueFromOutputArgInfo(const OutputArgInfo &info);
// Low-level invocation API. The number of inputs and outputs should be correct
// and match the results of getMetadata.
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs);
// Metadata for function `functionName`.
//
// Returns failure if functionName wasn't found.
LogicalResult getMetadata(ModuleDescriptor *moduleDescriptor,
StringRef functionName,
FunctionMetadata &outMetadata);
} // namespace refbackrt
#endif // NPCOMP_RUNTIME_USERAPI_H
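
The low-level flow these declarations imply, as a sketch. Obtaining the ModuleDescriptor* from a compiled artifact is backend-specific and assumed here, as are the "forward" name and the caller-supplied inputs.

// Sketch: metadata-driven invocation through the pure-runtime API.
#include "npcomp/RefBackend/Runtime/UserAPI.h"
#include <cstdint>
#include <vector>

bool runFunction(refbackrt::ModuleDescriptor *descriptor,
                 std::vector<refbackrt::RtValue> &inputs) {
  refbackrt::FunctionMetadata metadata;
  if (refbackrt::failed(
          refbackrt::getMetadata(descriptor, "forward", metadata)))
    return false; // unknown function
  // Pre-create outputs with the right tag/shape from the compiled metadata;
  // refbackrt::invoke needs correctly-typed slots to unpack results into.
  std::vector<refbackrt::RtValue> outputs(metadata.numOutputs);
  for (std::int32_t i = 0; i < metadata.numOutputs; i++)
    outputs[i] = refbackrt::createRtValueFromOutputArgInfo(
        metadata.outputArgInfos[i]);
  refbackrt::invoke(
      descriptor, "forward",
      refbackrt::ArrayRef<refbackrt::RtValue>(inputs.data(), inputs.size()),
      refbackrt::MutableArrayRef<refbackrt::RtValue>(outputs.data(),
                                                     outputs.size()));
  return true;
}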

View File

@@ -6,7 +6,6 @@ set(LLVM_LINK_COMPONENTS
add_npcomp_library(NPCOMPCAPI
InitLLVM.cpp
RefJITBackend.cpp
Registration.cpp
LINK_LIBS PUBLIC
@@ -14,8 +13,6 @@ add_npcomp_library(NPCOMPCAPI
MLIRLLVMIR
MLIRTargetLLVMIRExport
NPCOMPInitAll
NPCOMPRefBackendJITHelpers
NPCOMPRuntime
TorchMLIRTorchDialect
TorchMLIRInitAll

View File

@@ -1,127 +0,0 @@
//===- RefJITBackend.cpp - CAPI for RefJit --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp-c/RefJITBackend.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Pass.h"
#include "mlir/CAPI/Wrap.h"
#include "mlir/Pass/PassManager.h"
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
#include "llvm/ADT/Optional.h"
using namespace llvm;
using namespace mlir;
using namespace refback;
using namespace refbackrt;
using ValueListCpp = SmallVector<RtValue, 4>;
DEFINE_C_API_PTR_METHODS(NpcompRefJitModule, JITModule)
DEFINE_C_API_PTR_METHODS(NpcompRefJitValueList, ValueListCpp)
static_assert(static_cast<int>(ElementType::F32) == NPCOMP_REFJIT_F32,
"mismatched F32 mapping");
namespace {
template <typename T>
static Optional<T> checkError(llvm::Expected<T> &&expected,
char **errorMessageCstr, Twine banner = {}) {
if (LLVM_LIKELY(expected))
return std::move(*expected);
std::string errorMessage;
llvm::raw_string_ostream os(errorMessage);
llvm::logAllUnhandledErrors(expected.takeError(), os, banner);
os.flush();
*errorMessageCstr = strdup(errorMessage.c_str());
return llvm::None;
}
} // namespace
void npcompRefJitBuildBackendCompilationPipeline(MlirPassManager passManager,
bool optimize) {
JITModule::buildBackendCompilationPipeline(*unwrap(passManager), optimize);
}
NpcompRefJitModule npcompRefJitModuleCreate(MlirModule moduleOp,
MlirStringRef *sharedLibs,
intptr_t sharedLibsSize,
char **errorMessage) {
SmallVector<llvm::StringRef> sharedLibsCpp;
for (intptr_t i = 0; i < sharedLibsSize; ++i) {
sharedLibsCpp.push_back(
llvm::StringRef(sharedLibs[i].data, sharedLibs[i].length));
}
auto refJitModuleCpp =
checkError(JITModule::fromCompiledModule(unwrap(moduleOp), sharedLibsCpp),
errorMessage, "error creating refjit module");
if (!refJitModuleCpp)
return {nullptr};
return wrap(refJitModuleCpp->release());
}
void npcompRefJitModuleDestroy(NpcompRefJitModule module) {
delete unwrap(module);
}
bool npcompRefJitModuleInvoke(NpcompRefJitModule m, MlirStringRef functionName,
NpcompRefJitValueList inputOutputs,
char **errorMessage) {
ValueListCpp *ioList = unwrap(inputOutputs);
auto results = checkError(
unwrap(m)->invoke(llvm::StringRef(functionName.data, functionName.length),
*ioList),
errorMessage, "error invoking function");
ioList->clear();
if (!results)
return false;
for (int i = 0, e = results->size(); i < e; ++i) {
ioList->push_back(std::move((*results)[i]));
}
return true;
}
NpcompRefJitValueList npcompRefJitValueListCreate() {
return wrap(new ValueListCpp());
}
void npcompRefJitValueListDestroy(NpcompRefJitValueList list) {
delete unwrap(list);
}
intptr_t npcompRefJitValueListSize(NpcompRefJitValueList list) {
return unwrap(list)->size();
}
void npcompRefJitValueAddTensorCopy(NpcompRefJitValueList list,
NpcompRefJitElementType elementType,
const int32_t *extents,
intptr_t extentsSize, const void *data) {
ElementType elementTypeCpp = static_cast<ElementType>(elementType);
auto tensor =
Tensor::create(refbackrt::ArrayRef<std::int32_t>(extents, extentsSize),
elementTypeCpp, const_cast<void *>(data));
unwrap(list)->push_back(std::move(tensor));
}
bool npcompRefJitValueIsaTensor(NpcompRefJitValueList list, intptr_t i) {
return (*unwrap(list))[i].isTensor();
}
void *npcompRefJitValueGetTensor(NpcompRefJitValueList list, intptr_t i,
NpcompRefJitElementType *elementType,
intptr_t *rank, const int32_t **extents) {
auto tensor = (*unwrap(list))[i].toTensor();
*elementType = static_cast<NpcompRefJitElementType>(tensor->getElementType());
*rank = tensor->getRank();
*extents = tensor->getExtents().data();
return tensor->getData();
}

View File

@@ -3,7 +3,6 @@ add_subdirectory(CAPI)
add_subdirectory(Conversion)
add_subdirectory(Dialect)
add_subdirectory(Interfaces)
add_subdirectory(RefBackend)
################################################################################
# Setup the initialization target.
@@ -28,11 +27,8 @@ add_npcomp_library(NPCOMPInitAll
# Local depends
NPCOMPCommonBackend
NPCOMPIREEBackend
NPCOMPRefBackend
NPCOMPRefbackDialect
TorchMLIRTorchDialect
NPCOMPTorchConversionDialect
NPCOMPRefbackrtDialect
NPCOMPConversionPasses
IREEDialectsIREEDialect

View File

@@ -1,3 +1 @@
add_subdirectory(Refback)
add_subdirectory(Refbackrt)
add_subdirectory(TorchConversion)

View File

@@ -1 +0,0 @@
add_subdirectory(IR)

View File

@@ -1,19 +0,0 @@
add_npcomp_dialect_library(NPCOMPRefbackDialect
RefbackDialect.cpp
RefbackOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Refback
DEPENDS
MLIRRefbackOpsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
MLIRSideEffectInterfaces
MLIRShape
)

View File

@@ -1,42 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
#include "mlir/Transforms/InliningUtils.h"
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
using namespace mlir;
using namespace mlir::NPCOMP::refback;
#include "npcomp/Dialect/Refback/IR/RefbackOpsDialect.cpp.inc"
//===----------------------------------------------------------------------===//
// RefbackDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
struct RefbackInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
BlockAndValueMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
BlockAndValueMapping &) const final {
return true;
}
};
} // end anonymous namespace
void RefbackDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc"
>();
addInterfaces<RefbackInlinerInterface>();
}

View File

@@ -1,19 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::refback;
#define GET_OP_CLASSES
#include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc"

View File

@@ -1 +0,0 @@
add_subdirectory(IR)

View File

@@ -1,17 +0,0 @@
add_npcomp_dialect_library(NPCOMPRefbackrtDialect
RefbackrtDialect.cpp
RefbackrtOps.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Refbackrt
DEPENDS
MLIRRefbackrtOpsIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRSupport
)

View File

@@ -1,25 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
#include "mlir/IR/DialectImplementation.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::NPCOMP::refbackrt;
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOpsDialect.cpp.inc"
void RefbackrtDialect::initialize() {
addOperations<
#define GET_OP_LIST
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"
>();
addTypes<TensorType>();
}

View File

@@ -1,70 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP::refbackrt;
//===----------------------------------------------------------------------===//
// ModuleMetadataOp
//===----------------------------------------------------------------------===//
static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
p.printOptionalAttrDictWithKeyword(op->getAttrs());
p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
}
static ParseResult parseModuleMetadataOp(OpAsmParser &parser,
OperationState &result) {
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
auto *body = result.addRegion();
if (parser.parseRegion(*body, llvm::None, llvm::None))
return failure();
ModuleMetadataOp::ensureTerminator(*body, parser.getBuilder(),
result.location);
return success();
}
//===----------------------------------------------------------------------===//
// FuncMetadataOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(FuncMetadataOp op) {
auto *module = op->getParentOp()->getParentOp();
auto func = dyn_cast_or_null<FuncOp>(
SymbolTable::lookupSymbolIn(module, op.funcName()));
if (!func)
return op.emitError() << "must reference a valid func";
if (op.numInputs() != func.getNumArguments())
return op.emitError() << "must agree on number of inputs";
if (op.numOutputs() != func.getNumResults())
return op.emitError() << "must agree on number of outputs";
if (op.numInputs() > 0) {
if (op.numInputs() != op.inputArgTypes()->size()) {
return op.emitError() << "number of inputTypes must match number of inputs";
}
}
if (op.numOutputs() > 0) {
if (op.numOutputs() != op.outputArgTypes()->size())
return op.emitError() << "number of outputTypes must match number of outputs";
}
return success();
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"

View File

@@ -13,23 +13,17 @@
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Backend/IREE/Passes.h"
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
#include "npcomp/RefBackend/RefBackend.h"
void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry &registry) {
// clang-format off
registry.insert<refbackrt::RefbackrtDialect,
refback::RefbackDialect,
mlir::NPCOMP::TorchConversion::TorchConversionDialect,
registry.insert<mlir::NPCOMP::TorchConversion::TorchConversionDialect,
iree::IREEDialect>();
// clang-format on
}
void mlir::NPCOMP::registerAllPasses() {
mlir::NPCOMP::registerRefBackendPasses();
mlir::NPCOMP::registerConversionPasses();
mlir::NPCOMP::registerTorchConversionPasses();
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();

View File

@@ -1,35 +0,0 @@
add_subdirectory(Runtime)
add_subdirectory(JITHelpers)
add_npcomp_library(NPCOMPRefBackend
RefBackend.cpp
LowerToLLVM.cpp
LowerToRefbackrtABI.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SRC_DIR}/include/npcomp/RefBackend
DEPENDS
NPCOMPRefBackendPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRLinalg
MLIRLinalgToLLVM
MLIRLinalgTransforms
MLIRMathToLLVM
MLIRMathTransforms
MLIRMemRefToLLVM
MLIRSCFToStandard
MLIRSCFTransforms
MLIRShapeToStandard
MLIRStandard
MLIRStandardOpsTransforms
MLIRStandardToLLVM
MLIRTensorTransforms
)
mlir_check_all_link_libraries(NPCOMPRefBackend)

View File

@@ -1,16 +0,0 @@
add_npcomp_library(NPCOMPRefBackendJITHelpers
JITModule.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SRC_DIR}/include/npcomp/RefBackend/JITHelpers
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
NPCOMPRuntime
NPCOMPRefBackend
MLIRExecutionEngine
)
mlir_check_all_link_libraries(NPCOMPRefBackendJITHelpers)

View File

@@ -1,146 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "npcomp/RefBackend/RefBackend.h"
#include <sstream>
using namespace refback;
using namespace mlir;
using llvm::Error;
using llvm::Expected;
using llvm::StringError;
using llvm::Twine;
/// Wrap a string into an llvm::StringError.
static Error make_string_error(const Twine &message) {
return llvm::make_error<StringError>(message.str(),
llvm::inconvertibleErrorCode());
}
JITModule::JITModule() {}
void JITModule::buildBackendCompilationPipeline(PassManager &pm,
bool optimize) {
NPCOMP::RefBackendLoweringPipelineOptions options;
options.optimize = optimize;
NPCOMP::createRefBackendLoweringPipeline(pm, options);
}
llvm::Expected<std::unique_ptr<JITModule>>
JITModule::fromCompiledModule(mlir::ModuleOp module,
llvm::ArrayRef<llvm::StringRef> sharedLibs) {
// Ensure LLVM Dialect -> LLVM IR translations are available.
mlir::registerLLVMDialectTranslation(*module->getContext());
// Build the JITModule.
auto expectedEngine = ExecutionEngine::create(
module, /*llvmModuleBuilder=*/nullptr,
/*transformer=*/[](llvm::Module *) { return Error::success(); },
/*jitCodeGenOptLevel=*/llvm::None, llvm::to_vector<6>(sharedLibs));
if (!expectedEngine)
return expectedEngine.takeError();
std::unique_ptr<JITModule> ret(new JITModule);
ret->engine = std::move(*expectedEngine);
// Here we abuse mlir::ExecutionEngine a bit. It technically returns a
// function pointer, but here we look up a module descriptor.
auto expectedAddress = ret->engine->lookup("__npcomp_module_descriptor");
if (!expectedAddress)
return expectedAddress.takeError();
ret->descriptor =
reinterpret_cast<refbackrt::ModuleDescriptor *>(*expectedAddress);
return std::move(ret);
}
// Converter for bridging to refbackrt llvm-lookalike data structures.
static refbackrt::StringRef toRefbackrt(llvm::StringRef s) {
return refbackrt::StringRef(s.data(), s.size());
}
template <typename T>
static refbackrt::ArrayRef<T> toRefbackrt(llvm::ArrayRef<T> a) {
return refbackrt::ArrayRef<T>(a.data(), a.size());
}
template <typename T>
static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
}
static std::string stringifyShape(refbackrt::ArrayRef<std::int32_t> extents) {
static constexpr char kDynamicDimAsString[] = "?";
std::stringstream ss;
ss << "(";
for (int i = 0, e = extents.size(); i < e; i++) {
if (extents[i] < 0)
ss << kDynamicDimAsString;
else
ss << extents[i];
if (i != e - 1)
ss << "x";
}
ss << ")";
return ss.str();
}
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
JITModule::invoke(llvm::StringRef functionName,
llvm::ArrayRef<refbackrt::RtValue> inputs) {
refbackrt::FunctionMetadata metadata;
if (refbackrt::failed(refbackrt::getMetadata(
descriptor, toRefbackrt(functionName), metadata)))
return make_string_error("unknown function: " + Twine(functionName));
SmallVector<refbackrt::RtValue, 6> outputs(metadata.numOutputs);
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
return make_string_error("invoking '" + Twine(functionName) +
"': expected " + Twine(metadata.numInputs) +
" inputs");
// Verify user input types and shapes match what the compiler expects
for (int i = 0; i < metadata.numInputs; i++) {
auto &input = inputs[i];
auto &inputArgInfo = metadata.inputArgInfos[i];
if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo)))
return make_string_error(
"invoking '" + Twine(functionName) +
"': input argument type mismatch. actual (provided by user): " +
Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " +
Twine(getArgTypeAsStringRef(inputArgInfo.argType).str()));
if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo)))
return make_string_error(
"invoking '" + Twine(functionName) + "': input shape mismatch (%arg" +
Twine(i) + "). " + "actual (provided by user): " +
stringifyShape(input.toTensor()->getExtents()) +
", expected (from compiler): " +
stringifyShape(refbackrt::ArrayRef<int32_t>(
inputArgInfo.extents.data(), inputArgInfo.rank)));
}
// Create the correct output RtValue based on FuncMetadata,
// which contains the arg types (scalar, Tensor, etc.), element types (only
// applicable if not scalar) and shapes (also only applicable if not scalar)
//
// Currently we have to give each RtValue an output type so that we know
// how to pack / unpack the outputs properly across the ABI boundary in
// refbackrt::invoke. We can't just rely on default construction of each
// output argument (otherwise the RtValue would have Tag::None) without
// passing the ArgInfo structs down to the runtime level, so we deal with
// output type creation here.
for (int i = 0; i < metadata.numOutputs; i++) {
outputs[i] =
refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]);
}
refbackrt::invoke(
descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
return outputs;
}

View File

@@ -1,747 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/RefBackend/RefBackend.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
using mlir::LLVM::LLVMArrayType;
using mlir::LLVM::LLVMFuncOp;
using mlir::LLVM::LLVMFunctionType;
using mlir::LLVM::LLVMPointerType;
using mlir::LLVM::LLVMStructType;
using mlir::LLVM::LLVMVoidType;
//===----------------------------------------------------------------------===//
// Descriptor types shared with the runtime.
//
// These correspond to the types in CompilerDataStructures.h
//===----------------------------------------------------------------------===//
// MaxRank that the refbackrt ABI lowering is capable of handling
// NOTE: This parameter must stay consistent with
// `lib/RefBackend/LowerToRefbackrtABI.cpp`
static constexpr int kMaxRank = 6;
static LLVMPointerType getInt8PointerType(MLIRContext *context) {
return LLVMPointerType::get(IntegerType::get(context, 8));
}
static LLVMPointerType getInt32PointerType(MLIRContext *context) {
return LLVMPointerType::get(IntegerType::get(context, 32));
}
static LLVMStructType getInputDescriptorTy(MLIRContext *context) {
return LLVMStructType::getLiteral(
context, {
// ArgType
IntegerType::get(context, 32),
// ElementType
IntegerType::get(context, 32),
// Rank
IntegerType::get(context, 32),
// Extents
LLVMPointerType::get(IntegerType::get(context, 32)),
// IsStatic
// IntegerType::get(context, 32),
});
}
static LLVMStructType getOutputDescriptorTy(MLIRContext *context) {
return LLVMStructType::getLiteral(
context, {
// ArgType
IntegerType::get(context, 32),
// ElementType
IntegerType::get(context, 32),
// Rank
IntegerType::get(context, 32),
// Extents
LLVMPointerType::get(IntegerType::get(context, 32)),
// IsStatic
// IntegerType::get(context, 32),
});
}
// Get the LLVM type for refbackrt::FuncDescriptor.
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
return LLVMStructType::getLiteral(
context, {
// Name length.
IntegerType::get(context, 32),
// Name chars.
getInt8PointerType(context),
// Type-erased function pointer.
getInt8PointerType(context),
// Number of inputs.
IntegerType::get(context, 32),
// Number of outputs.
IntegerType::get(context, 32),
// Argument descriptors
LLVMPointerType::get(getInputDescriptorTy(context)),
// Result Descriptors
LLVMPointerType::get(getOutputDescriptorTy(context)),
});
}
// Get the LLVM type for refbackrt::ModuleDescriptor.
static LLVMStructType getModuleDescriptorTy(MLIRContext *context) {
return LLVMStructType::getLiteral(
context, {
// std::int32_t numFuncDescriptors;
IntegerType::get(context, 32),
// FuncDescriptor *functionDescriptors;
LLVMPointerType::get(getFuncDescriptorTy(context)),
});
}
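// For reference (illustrative, not from the original file): the literal
// struct type built by getFuncDescriptorTy prints roughly as
//
//   !llvm.struct<(i32,        // name length
//                 ptr<i8>,    // name chars
//                 ptr<i8>,    // type-erased function pointer
//                 i32,        // number of inputs
//                 i32,        // number of outputs
//                 ptr<...>,   // input descriptors
//                 ptr<...>)>  // output descriptors
//
// which mirrors refbackrt::FuncDescriptor in CompilerDataStructures.h.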
//===----------------------------------------------------------------------===//
// Compiler runtime functions.
//===----------------------------------------------------------------------===//
namespace {
template <typename T>
class TrivialCompilerRuntimeLowering : public OpConversionPattern<T> {
public:
TrivialCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
: OpConversionPattern<T>(backingFunc.getContext()),
backingFunc(backingFunc) {}
LogicalResult
matchAndRewrite(T op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc, operands);
return success();
}
LLVM::LLVMFuncOp backingFunc;
};
} // namespace
static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
OpBuilder &builder, Location loc) {
// TODO: Deduplicate strings.
std::string msgNulTerminated = msg.getValue().str();
msgNulTerminated.push_back('\0');
auto arrayTy = LLVMArrayType::get(IntegerType::get(module.getContext(), 8),
msgNulTerminated.size());
OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(module.getBody());
// To get a unique symbol name, use a suffix derived from the current number
// of ops in the module.
// We can't use the SymbolTable's logic for this because the module
// transiently contains a `func` and `llvm.func` with the same name during
// conversion, preventing us from instantiating a SymbolTable.
std::string symbolName =
(Twine("__npcomp_string_") +
Twine(llvm::size(llvm::to_vector<6>(module.getOps<LLVM::GlobalOp>()))))
.str();
auto globalOp = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal, symbolName,
builder.getStringAttr(msgNulTerminated));
return globalOp;
}
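// Illustrative output (roughly; the exact printed form may differ): for a
// message "some error", this builds
//
//   llvm.mlir.global internal constant @__npcomp_string_0("some error\00")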
namespace {
class AbortIfOpCompilerRuntimeLowering
: public OpConversionPattern<refbackrt::AbortIfOp> {
public:
AbortIfOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
: OpConversionPattern<refbackrt::AbortIfOp>(backingFunc.getContext()),
backingFunc(backingFunc) {}
LogicalResult
matchAndRewrite(refbackrt::AbortIfOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
refbackrt::AbortIfOp::Adaptor adaptor(operands);
auto *context = op.getContext();
// Create the global string, take its address, and gep to get an `i8*`.
auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
op.msgAttr(), rewriter, op.getLoc());
auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
auto c0 = rewriter.create<LLVM::ConstantOp>(op.getLoc(),
IntegerType::get(context, 32),
rewriter.getI32IntegerAttr(0));
auto msg =
rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
msgArray, ValueRange({c0, c0}));
rewriter.replaceOpWithNewOp<LLVM::CallOp>(
op, backingFunc, ValueRange({adaptor.pred(), msg}));
return success();
}
LLVM::LLVMFuncOp backingFunc;
};
} // namespace
// Create the LLVM runtime function declaration backing the refbackrt op with
// name `name` and function type `type`.
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, Type type,
OpBuilder &builder,
Location loc) {
assert(type.isa<LLVMFunctionType>());
std::string symbolName = (Twine("__npcomp_compiler_rt_") + name).str();
return builder.create<LLVM::LLVMFuncOp>(loc, symbolName, type,
LLVM::Linkage::External);
}
static void populateCompilerRuntimePatterns(ModuleOp module,
RewritePatternSet &patterns,
LLVMTypeConverter &typeConverter) {
auto *context = module.getContext();
OpBuilder builder(module.getBodyRegion());
{
auto abortIfFuncTy = LLVMFunctionType::get(
LLVMVoidType::get(context),
{IntegerType::get(context, 1), getInt8PointerType(context)},
/*isVarArg=*/false);
LLVMFuncOp abortIfFunc = createCompilerRuntimeFuncDecl(
"abort_if", abortIfFuncTy, builder, module.getLoc());
patterns.add<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
}
}
//===----------------------------------------------------------------------===//
// Lowering for module metadata
//===----------------------------------------------------------------------===//
static LLVM::GlobalOp
createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
OpBuilder &builder, Location loc) {
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
DenseMap<StringRef, LLVM::GlobalOp> inputDescriptorsByName;
DenseMap<StringRef, LLVM::GlobalOp> outputDescriptorsByName;
DenseMap<StringRef, LLVM::GlobalOp> inputShapesByName;
DenseMap<StringRef, LLVM::GlobalOp> outputShapesByName;
for (auto funcMetadata : funcMetadatas) {
auto arrayTy = LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
funcMetadata.funcName().size());
std::string llvmSymbolName =
(Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
auto global = builder.create<LLVM::GlobalOp>(
loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
globalsByName[funcMetadata.funcName()] = global;
// Create constants for the input / output shapes
if (funcMetadata.inputShapes().hasValue()) {
auto i32ArrayInputSymbolName =
(Twine("__npcomp_internal_constant_input_shapes_") +
funcMetadata.funcName())
.str();
auto inputNumElements = funcMetadata.inputShapes()->getNumElements();
auto inputI32ArrayTy =
LLVMArrayType::get(builder.getIntegerType(32), inputNumElements);
auto inputShapesGlobal = builder.create<LLVM::GlobalOp>(
loc, inputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
i32ArrayInputSymbolName,
/*value=*/funcMetadata.inputShapes().getValue());
inputShapesByName[funcMetadata.funcName()] = inputShapesGlobal;
}
if (funcMetadata.outputShapes().hasValue()) {
auto i32ArrayOutputSymbolName =
(Twine("__npcomp_internal_constant_output_shapes_") +
funcMetadata.funcName())
.str();
auto outputNumElements = funcMetadata.outputShapes()->getNumElements();
auto outputI32ArrayTy =
LLVMArrayType::get(builder.getIntegerType(32), outputNumElements);
auto outputShapesGlobal = builder.create<LLVM::GlobalOp>(
loc, outputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
i32ArrayOutputSymbolName,
/*value=*/funcMetadata.outputShapes().getValue());
outputShapesByName[funcMetadata.funcName()] = outputShapesGlobal;
}
}
auto updateDescriptor = [&](Value &descriptor, Value value,
std::initializer_list<int32_t> position) {
descriptor = builder.create<LLVM::InsertValueOp>(
loc, descriptor, value,
/*position=*/builder.getI32ArrayAttr(position));
};
auto updateDescriptorWithI32Attr =
[&](Value &descriptor, Attribute attr,
std::initializer_list<int32_t> position) {
auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
updateDescriptor(descriptor, constant, position);
};
// Create global input descriptors
for (auto funcMetadata : funcMetadatas) {
std::string llvmInputSymbolName =
(Twine("__npcomp_input_descriptors_") + funcMetadata.funcName()).str();
auto inputDescriptorTy = getInputDescriptorTy(builder.getContext());
auto inputDescriptorArrayTy =
LLVMArrayType::get(inputDescriptorTy, funcMetadata.numInputs());
auto inputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, inputDescriptorArrayTy, /*isConstant=*/true,
LLVM::Linkage::Internal, llvmInputSymbolName, /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&inputDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
Value inputDescriptorArray =
builder.create<LLVM::UndefOp>(loc, inputDescriptorArrayTy);
for (int i = 0, e = funcMetadata.numInputs(); i < e; i++) {
// Arg Type
if (!funcMetadata.inputArgTypes().hasValue())
funcMetadata.emitError()
<< "numInputs > 0 but there are no inputArgTypes?";
updateDescriptorWithI32Attr(inputDescriptorArray,
funcMetadata.inputArgTypes()->getValue(i),
{i, 0});
// Element Type
updateDescriptorWithI32Attr(inputDescriptorArray,
funcMetadata.inputElementTypes()->getValue(i),
{i, 1});
// Rank
// auto inputShapesType =
// funcMetadata.inputShapes()->getType().dyn_cast<ShapedType>();
auto rank = funcMetadata.inputRanks()->getValue(i);
updateDescriptorWithI32Attr(inputDescriptorArray, rank, {i, 2});
      // Shape
      // Each shape array is derived by offsetting by kMaxRank * arg index.
auto extentsArray = builder.create<LLVM::AddressOfOp>(
loc, inputShapesByName[funcMetadata.funcName()]);
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(i * kMaxRank));
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
loc, getInt32PointerType(builder.getContext()), extentsArray,
ValueRange({c0, cShapeOffset}));
updateDescriptor(inputDescriptorArray, extentsArrayPtr, {i, 3});
}
builder.create<LLVM::ReturnOp>(loc, inputDescriptorArray);
inputDescriptorsByName[funcMetadata.funcName()] =
std::move(inputDescriptorArrayGlobal);
}
// Create global output descriptors
for (auto funcMetadata : funcMetadatas) {
std::string llvmOutputSymbolName =
(Twine("__npcomp_output_descriptors_") + funcMetadata.funcName()).str();
auto outputDescriptorTy = getOutputDescriptorTy(builder.getContext());
auto outputDescriptorArrayTy =
LLVMArrayType::get(outputDescriptorTy, funcMetadata.numOutputs());
auto outputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, outputDescriptorArrayTy, /*isConstant=*/true,
LLVM::Linkage::Internal, llvmOutputSymbolName, /*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&outputDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
Value outputDescriptorArray =
builder.create<LLVM::UndefOp>(loc, outputDescriptorArrayTy);
for (int i = 0, e = funcMetadata.numOutputs(); i < e; i++) {
if (!funcMetadata.outputArgTypes().hasValue())
funcMetadata.emitError()
<< "numOutputs > 0 but there are no outputArgTypes?";
// Arg Type
updateDescriptorWithI32Attr(outputDescriptorArray,
funcMetadata.outputArgTypes()->getValue(i),
{i, 0});
// Element Type
updateDescriptorWithI32Attr(
outputDescriptorArray, funcMetadata.outputElementTypes()->getValue(i),
{i, 1});
// Rank
// auto outputShapesType =
// funcMetadata.outputShapes()->getType().dyn_cast<ShapedType>();
auto rank = funcMetadata.outputRanks()->getValue(i);
updateDescriptorWithI32Attr(outputDescriptorArray, rank, {i, 2});
// Shapes
// Offset by kMaxRank * arg index
auto extentsArray = builder.create<LLVM::AddressOfOp>(
loc, outputShapesByName[funcMetadata.funcName()]);
auto cShapeOffset = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(i * kMaxRank));
auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
loc, getInt32PointerType(builder.getContext()), extentsArray,
ValueRange({c0, cShapeOffset}));
updateDescriptor(outputDescriptorArray, extentsArrayPtr, {i, 3});
}
builder.create<LLVM::ReturnOp>(loc, outputDescriptorArray);
outputDescriptorsByName[funcMetadata.funcName()] =
outputDescriptorArrayGlobal;
}
// This must match FuncDescriptor in the runtime.
auto funcDescriptorTy = getFuncDescriptorTy(builder.getContext());
auto funcDescriptorArrayTy =
LLVMArrayType::get(funcDescriptorTy, funcMetadatas.size());
auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
"__npcomp_func_descriptors",
/*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&funcDescriptorArrayGlobal.initializer());
auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
builder.getI32IntegerAttr(0));
// Build the initializer.
Value funcDescriptorArray =
builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);
for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
auto funcMetadata = funcMetadataAndIndex.value();
int32_t index = funcMetadataAndIndex.index();
// Name length.
updateDescriptorWithI32Attr(
funcDescriptorArray,
builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});
// Name chars.
auto funcNameArray = builder.create<LLVM::AddressOfOp>(
loc, globalsByName[funcMetadata.funcName()]);
auto funcNamePtr = builder.create<LLVM::GEPOp>(
loc, getInt8PointerType(builder.getContext()), funcNameArray,
ValueRange({c0, c0}));
updateDescriptor(funcDescriptorArray, funcNamePtr, {index, 1});
// Function pointer.
//
// We create this reference to the original function (and use a dummy i8*
// type). We will fix this up after conversion to point at wrapper
// functions that satisfy the ABI requirements.
// The bitcast is required so that after conversion the inserted value is an
// i8* as expected by the descriptor struct.
auto funcAddress = builder.create<LLVM::AddressOfOp>(
loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
loc, getInt8PointerType(builder.getContext()), funcAddress);
updateDescriptor(funcDescriptorArray, typeErasedFuncAddress, {index, 2});
// Number of inputs.
updateDescriptorWithI32Attr(funcDescriptorArray,
funcMetadata.numInputsAttr(), {index, 3});
// Number of outputs.
updateDescriptorWithI32Attr(funcDescriptorArray,
funcMetadata.numOutputsAttr(), {index, 4});
// Input descriptors
auto inputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
loc, inputDescriptorsByName[funcMetadata.funcName()]);
auto rawInputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
loc, LLVMPointerType::get(getInputDescriptorTy(builder.getContext())),
inputDescriptorsArrayAddress);
updateDescriptor(funcDescriptorArray, rawInputDescriptorsPtr, {index, 5});
// Output descriptors
auto outputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
loc, outputDescriptorsByName[funcMetadata.funcName()]);
auto rawOutputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
loc, LLVMPointerType::get(getOutputDescriptorTy(builder.getContext())),
outputDescriptorsArrayAddress);
updateDescriptor(funcDescriptorArray, rawOutputDescriptorsPtr, {index, 6});
}
builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);
return funcDescriptorArrayGlobal;
}
LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
OpBuilder &builder, Location loc) {
auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
auto moduleDescriptorTy = getModuleDescriptorTy(builder.getContext());
// TODO: Ideally this symbol name would somehow be related to the module
// name, if we could consistently assume we had one.
  // TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
  // is typically only meant for function pointers) will find this raw symbol.
auto moduleDescriptorGlobal = builder.create<LLVM::GlobalOp>(
loc, moduleDescriptorTy, /*isConstant=*/true, LLVM::Linkage::External,
"_mlir___npcomp_module_descriptor",
/*value=*/Attribute());
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(&moduleDescriptorGlobal.initializer());
Value moduleDescriptor =
builder.create<LLVM::UndefOp>(loc, moduleDescriptorTy);
auto updateDescriptor = [&](Value value,
std::initializer_list<int32_t> position) {
moduleDescriptor = builder.create<LLVM::InsertValueOp>(
loc, moduleDescriptor, value,
/*position=*/builder.getI32ArrayAttr(position));
};
updateDescriptor(builder.create<LLVM::ConstantOp>(
loc, llvmI32Ty,
builder.getI32IntegerAttr(funcDescriptorArray.getType()
.cast<LLVMArrayType>()
.getNumElements())),
{0});
  auto funcDescriptorArrayAddress =
      builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray);
  auto rawFuncDescriptorPtr = builder.create<LLVM::BitcastOp>(
      loc, LLVMPointerType::get(getFuncDescriptorTy(builder.getContext())),
      funcDescriptorArrayAddress);
updateDescriptor(rawFuncDescriptorPtr, {1});
builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);
return moduleDescriptorGlobal;
}
namespace {
class LowerModuleMetadata
: public OpConversionPattern<refbackrt::ModuleMetadataOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(refbackrt::ModuleMetadataOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcMetadatas =
llvm::to_vector<6>(op.metadatas().getOps<refbackrt::FuncMetadataOp>());
auto funcDescriptorArray =
createFuncDescriptorArray(funcMetadatas, rewriter, op.getLoc());
auto moduleDescriptor =
createModuleDescriptor(funcDescriptorArray, rewriter, op.getLoc());
// TODO: create get module descriptor wrapper (or upgrade
// mlir::ExecutionEngine to allow raw symbol lookup)
(void)moduleDescriptor;
rewriter.eraseOp(op);
return success();
}
};
} // namespace
// Performs the calculation:
// ```
// ty *f(void **voidStarStar, int32_t i) {
// return reinterpret_cast<ty *>(voidStarStar[i]);
// }
// ```
static Value getTypedAddressFromVoidStarStar(Value voidStarStar, int32_t index,
Type ty, OpBuilder &builder,
Location loc) {
Value ci = builder.create<LLVM::ConstantOp>(
loc, IntegerType::get(builder.getContext(), 32),
builder.getI32IntegerAttr(index));
// Do `voidStarStar[i]` as a gep + load.
auto inputPtrAddr = builder.create<LLVM::GEPOp>(
loc, LLVMPointerType::get(getInt8PointerType(builder.getContext())),
voidStarStar, ValueRange(ci));
auto inputPtr = builder.create<LLVM::LoadOp>(loc, inputPtrAddr);
return builder.create<LLVM::BitcastOp>(loc, LLVMPointerType::get(ty),
inputPtr);
}
static SmallVector<Value, 6> loadCallArgs(Value inputsPtrPtr,
LLVMFunctionType funcTy,
OpBuilder &builder, Location loc) {
SmallVector<Value, 6> callArgs;
// For each void* in the void**, cast it to the right type and load it.
for (int i = 0, e = funcTy.getNumParams(); i < e; i++) {
auto paramTy = funcTy.getParamType(i);
auto addr =
getTypedAddressFromVoidStarStar(inputsPtrPtr, i, paramTy, builder, loc);
callArgs.push_back(builder.create<LLVM::LoadOp>(loc, addr));
}
return callArgs;
}
static Type getUnrankedMemrefDescriptorType(MLIRContext *context) {
LLVMTypeConverter converter(context);
// LLVMTypeConverter doesn't directly expose the struct type used to represent
// unranked memrefs on ABI boundaries. To get that type, we convert
// an unranked memref type and see what it produces.
//
// An unranked memref is just a size_t for the rank and a void* pointer to the
// descriptor, so the choice of element type here is arbitrary -- it all
// converts to the same thing.
return converter.convertType(
UnrankedMemRefType::get(Float32Type::get(context),
/*memorySpace=*/0));
}
static Type getFloatType(MLIRContext *context) {
LLVMTypeConverter converter(context);
return converter.convertType(FloatType::getF32(context));
}
// Writes out the logical results of the wrapper function through the void**
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
// only supports a single return type (or void/no results), the logic here needs
// to be aware of the convention used in the Std to LLVM conversion to map
// multiple return types. The details of this are in the function
// packFunctionResults and its callers:
// https://github.com/llvm/llvm-project/blob/fad9cba8f58ba9979f390a49cf174ec9fcec29a6/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L282
static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
OpBuilder &builder, Location loc) {
// 0 results. Nothing to do.
if (callToWrapped.getNumResults() == 0)
return;
Value result = callToWrapped.getResult(0);
auto ty = result.getType();
// 1 logical result.
if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr);
return;
} else if (ty == getFloatType(ty.getContext())) {
Value addr =
getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
builder.create<LLVM::StoreOp>(loc, result, addr);
return;
}
assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
auto structType = ty.cast<LLVMStructType>();
// >=2 logical results. The convention linked above will create a struct
// wrapping.
for (int i = 0, e = structType.getBody().size(); i < e; i++) {
auto elementTy = structType.getBody()[i];
Value addr = getTypedAddressFromVoidStarStar(resultsPtrPtr, i, elementTy,
builder, loc);
int32_t i32I = i;
Value value = builder.create<LLVM::ExtractValueOp>(
loc, elementTy, result, builder.getI32ArrayAttr({i32I}));
builder.create<LLVM::StoreOp>(loc, value, addr);
}
}
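// Worked example (illustrative): for a wrapped function returning
// (memref<*xf32>, f32), the convention linked above packs both results into a
// single !llvm.struct<(...)> value; the loop extracts field 0 and field 1 and
// stores them through results[0] and results[1] respectively.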
// Construct a wrapper function.
// For an externally visible function f(T1, T2) -> T3, T4, we create a
// wrapper
// __refbackrt_wrapper_f(void **inputs, void ** outputs) {
// T3 t3;
// T4 t4;
// (t3, t4) = f(*cast<T1*>(inputs[0]), *cast<T2*>(inputs[1]));
// *cast<T3*>(outputs[0]) = t3;
// *cast<T4*>(outputs[1]) = t4;
// }
// This is very similar to MLIR's "packed" convention, but supporting
// outputs.
// TODO: Extend MLIR's void** wrappers to have outputs in this way.
static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
auto *context = func.getContext();
LLVMFunctionType funcTy = func.getType();
auto voidStarTy = getInt8PointerType(context);
auto voidStarStarTy = LLVMPointerType::get(voidStarTy);
auto wrapperTy = LLVMFunctionType::get(LLVMVoidType::get(context),
{voidStarStarTy, voidStarStarTy},
/*isVarArg=*/false);
constexpr char kRefbackrtWrapperPrefix[] = "__refbackrt_wrapper_";
auto wrapperName = (Twine(kRefbackrtWrapperPrefix) + func.getName()).str();
OpBuilder moduleBuilder(func->getParentRegion());
LLVMFuncOp wrapper = moduleBuilder.create<LLVMFuncOp>(
func.getLoc(), wrapperName, wrapperTy, LLVM::Linkage::External);
// Create the function body.
Block &body = *wrapper.addEntryBlock();
auto builder = OpBuilder::atBlockBegin(&body);
auto callArgs =
loadCallArgs(body.getArgument(0), funcTy, builder, func.getLoc());
auto call = builder.create<LLVM::CallOp>(func.getLoc(), func, callArgs);
storeWrapperResults(call, body.getArgument(1), builder, func.getLoc());
builder.create<LLVM::ReturnOp>(func.getLoc(), ValueRange());
return wrapper;
}
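// Host-side view (illustrative; `wrapperAddr`, `arg0`, and `res0` are
// hypothetical): after lowering, the wrapper can be called through the fixed
// ABI without knowing the original register-level signature:
//
//   using ABIFunc = void(void **, void **);
//   void *inputs[] = {&arg0};   // pointer to T1 storage
//   void *outputs[] = {&res0};  // pointer to T3 storage
//   reinterpret_cast<ABIFunc *>(wrapperAddr)(inputs, outputs);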
namespace {
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<LLVM::LLVMDialect>();
}
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();
LLVMTypeConverter converter(context);
RewritePatternSet patterns(context);
LLVMConversionTarget target(*context);
populateCompilerRuntimePatterns(module, patterns, converter);
target.addLegalOp<ModuleOp>();
populateStdToLLVMConversionPatterns(converter, patterns);
populateMathToLLVMConversionPatterns(converter, patterns);
populateMemRefToLLVMConversionPatterns(converter, patterns);
patterns.add<LowerModuleMetadata>(context);
// TODO: Move these "std to std" legalizations to their own pass if we grow
// lots of these patterns.
populateExpandTanhPattern(patterns);
populateLinalgToLLVMConversionPatterns(converter, patterns);
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
return signalPassFailure();
}
// Rewrite llvm.mlir.addressof ops that reference the original exported
// functions from the module to instead refer to wrapper functions.
// These wrapper functions have a fixed ABI
// (`void f(void **inputs, void **results)`) which we can interface to from
// external code without dealing with platform-dependent
// register-level calling conventions. We embed enough information in the
// module metadata to make sure that calling code can e.g. preallocate
// enough outputs and with the right types to safely funnel through this
// convention.
module.walk([&](LLVM::AddressOfOp op) {
auto originalFunc =
module.lookupSymbol<LLVM::LLVMFuncOp>(op.global_name());
if (!originalFunc)
return;
auto wrapper = createWrapperFunc(originalFunc);
op.getResult().setType(LLVMPointerType::get(wrapper.getType()));
Builder builder(op.getContext());
op->setAttr("global_name",
SymbolRefAttr::get(builder.getContext(), wrapper.getName()));
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>> mlir::NPCOMP::createLowerToLLVMPass() {
return std::make_unique<LowerToLLVM>();
}

View File

@ -1,446 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "PassDetail.h"
#include "npcomp/RefBackend/RefBackend.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
// Since the input/output shapes of a function may have different ranks, the
// set of shapes is not hyper-rectangular. We therefore fix a maximum rank and
// pad every shape out to kMaxRank at the ABI boundary, so that all the shapes
// together can be represented with a traditional DenseElementsAttr.
//
// NOTE: When changing this parameter, also change the same `kMaxRank`
// parameter in `lib/RefBackend/LowerToLLVM.cpp` so that the LLVM lowering
// stays consistent.
static constexpr int kMaxRank = 6;
// Get the type used to represent MemRefType `type` on ABI boundaries.
// For convenience we do a cast to MemRefType internally.
static Type getABIMemrefType(Type type) {
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(),
/*memorySpace=*/0);
}
//===----------------------------------------------------------------------===//
// Creating module metadata.
//===----------------------------------------------------------------------===//
// Returns true if the function signature can be expressed with the refbackrt
// ABI.
static bool expressibleWithRefbackrtABI(FunctionType type) {
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
return llvm::all_of(
llvm::concat<const Type>(type.getInputs(), type.getResults()),
[](Type t) {
return t.isa<UnrankedMemRefType, MemRefType, FloatType>();
});
}
// Returns the integer representation of the CompilerDataStructures::ABIArgType.
// Must stay aligned with the CompilerDataStructures::ABIArgType enum.
static uint32_t getIntReprForABIType(Type type) {
if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) {
return 1;
} else if (auto floatTy = type.dyn_cast<FloatType>()) {
switch (floatTy.getWidth()) {
case 32:
return 2;
case 64:
return 3;
default:
assert(false && "Unsupported float bit width");
}
  } else if (type.isa<IntegerType>()) {
    // TODO: Support integer types at the ABI boundary.
  }
  // Unknown types map to an all-ones sentinel (returned as uint32_t).
  return -1;
}
// Must stay aligned with CompilerDataStructures::ABIElementType enum
static uint32_t getIntReprForABIElementType(Type type) {
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
auto elemTy = shapedTy.getElementType();
if (auto floatTy = elemTy.dyn_cast<FloatType>()) {
switch (floatTy.getWidth()) {
case 32:
return 1;
default:
assert(false && "Unsupported tensor element type");
}
}
}
return 0;
}
static SmallVector<int32_t, kMaxRank>
getExtentsForType(Type type, const int32_t maxRank = kMaxRank) {
  // Pad all shapes out to kMaxRank to make our lives easier at the ABI
  // boundary.
if (auto shapedType = type.dyn_cast<ShapedType>()) {
if (!shapedType.hasRank()) {
return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
}
auto shape = shapedType.getShape();
auto shapeRank = shapedType.getRank();
if (shapeRank <= maxRank) {
SmallVector<int32_t, kMaxRank> extendedShape;
// Push back all the values of the shape
for (auto extentAndIndex : llvm::enumerate(shape)) {
auto extent = extentAndIndex.value();
auto index = extentAndIndex.index();
if (shapedType.isDynamic(index)) {
extendedShape.push_back(-1);
} else {
extendedShape.push_back(extent);
}
}
      // Pad the remaining slots so every shape occupies exactly maxRank
      // entries.
auto padRank = maxRank - shapeRank;
for (int i = 0; i < padRank; i++)
extendedShape.push_back(0xDEAD'BEEF);
return extendedShape;
} else {
assert(false && "unsupported rank");
}
}
  // Represent scalars (and other non-shaped types) with the same
  // sentinel-filled vector used for unranked shapes.
return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
}
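// Worked example (illustrative): with kMaxRank = 6, an argument of type
// tensor<2x?x3xf32> yields extents {2, -1, 3, 0xDEADBEEF, 0xDEADBEEF,
// 0xDEADBEEF}: dynamic dims become -1 and the unused trailing slots hold the
// 0xDEAD'BEEF padding sentinel.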
int32_t getRankForType(Type type) {
// Returns a rank of -1 if the tensor is unranked
if (auto shapedType = type.dyn_cast<ShapedType>()) {
return shapedType.hasRank() ? shapedType.getRank() : -1;
}
return 0;
}
uint32_t hasStaticShape(Type type) {
if (auto shapedType = type.dyn_cast<ShapedType>()) {
return shapedType.hasStaticShape() ? 1 : 0;
}
// Assume scalars and non-shaped type things are static
return 1;
}
static LogicalResult createModuleMetadata(ModuleOp module) {
auto moduleMetadata =
OpBuilder::atBlockBegin(module.getBody())
.create<refbackrt::ModuleMetadataOp>(module.getLoc());
moduleMetadata.metadatas().push_back(new Block);
Block &metadatas = moduleMetadata.metadatas().front();
OpBuilder::atBlockEnd(&metadatas)
.create<refbackrt::ModuleMetadataTerminatorOp>(module.getLoc());
SymbolTable symbolTable(module);
auto builder = OpBuilder::atBlockBegin(&metadatas);
for (auto func : module.getOps<FuncOp>()) {
if (symbolTable.getSymbolVisibility(func) !=
SymbolTable::Visibility::Public) {
continue;
}
// TODO: Add richer information here such as expected shapes and element
// types.
SmallVector<uint32_t, 6> inputABIArgTypes;
SmallVector<uint32_t, 6> inputABIElementTypes;
SmallVector<SmallVector<int32_t, kMaxRank>, 6> inputABIShapes;
SmallVector<uint32_t, 6> inputABIRanks;
// SmallVector<uint32_t, 6> inputIsStatic;
for (const auto &inputArgType : func.getBody().front().getArgumentTypes()) {
inputABIArgTypes.push_back(getIntReprForABIType(inputArgType));
inputABIElementTypes.push_back(getIntReprForABIElementType(inputArgType));
inputABIShapes.push_back(
getExtentsForType(inputArgType, /*maxRank=*/kMaxRank));
inputABIRanks.push_back(getRankForType(inputArgType));
// inputIsStatic.push_back(hasStaticShape(inputArgType));
}
SmallVector<uint32_t, 6> outputABIArgTypes;
SmallVector<uint32_t, 6> outputABIElementTypes;
SmallVector<SmallVector<int32_t, kMaxRank>, 6> outputABIShapes;
SmallVector<uint32_t, 6> outputABIRanks;
SmallVector<uint32_t, 6> outputIsStatic;
for (const auto &outputArgType : func.getCallableResults()) {
outputABIArgTypes.push_back(getIntReprForABIType(outputArgType));
outputABIElementTypes.push_back(
getIntReprForABIElementType(outputArgType));
outputABIShapes.push_back(
getExtentsForType(outputArgType, /*maxRank=*/kMaxRank));
outputABIRanks.push_back(getRankForType(outputArgType));
// outputIsStatic.push_back(hasStaticShape(outputArgType));
}
auto i32Type = builder.getIntegerType(32);
auto inputABIDataType =
RankedTensorType::get(inputABIArgTypes.size(), i32Type);
auto inputABIElementType =
RankedTensorType::get(inputABIElementTypes.size(), i32Type);
auto inputABIShapesType = RankedTensorType::get(
llvm::ArrayRef<int64_t>{static_cast<long>(inputABIShapes.size()) *
kMaxRank},
i32Type);
auto inputABIRanksType =
RankedTensorType::get(inputABIRanks.size(), i32Type);
// auto inputIsStaticType = RankedTensorType::get(inputIsStatic.size(),
// i32Type);
auto outputABIDataType =
RankedTensorType::get(outputABIArgTypes.size(), i32Type);
auto outputABIElementType =
RankedTensorType::get(outputABIElementTypes.size(), i32Type);
auto outputABIShapesType = RankedTensorType::get(
llvm::ArrayRef<int64_t>{static_cast<long>(outputABIShapes.size()) *
kMaxRank},
i32Type);
auto outputABIRanksType =
RankedTensorType::get(outputABIRanks.size(), i32Type);
// auto outputIsStaticType = RankedTensorType::get(outputIsStatic.size(),
// i32Type);
// TODO(brycearden): I'm sure there's a cleaner way to do this
auto flattenABIShapes =
[](SmallVector<SmallVector<int32_t, kMaxRank>, 6> shapes) {
SmallVector<int32_t, 32> ret;
for (auto &shape : shapes)
for (auto &dim : shape)
ret.push_back(dim);
return ret;
};
SmallVector<NamedAttribute, 16> namedAttrs;
// Add attributes that are valid for every func (funcName, numInputs,
// numOutputs)
namedAttrs.push_back(std::make_pair(
Identifier::get("funcName", module.getContext()),
SymbolRefAttr::get(builder.getContext(), func.getName())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numInputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumArguments())));
namedAttrs.push_back(
std::make_pair(Identifier::get("numOutputs", module.getContext()),
builder.getI32IntegerAttr(func.getNumResults())));
if (inputABIArgTypes.size()) {
// Only add input information if there are inputs
namedAttrs.push_back(std::make_pair(
Identifier::get("inputArgTypes", func.getContext()),
DenseIntElementsAttr::get(inputABIDataType,
llvm::makeArrayRef(inputABIArgTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("inputElementTypes", func.getContext()),
DenseIntElementsAttr::get(inputABIElementType,
llvm::makeArrayRef(inputABIElementTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("inputRanks", func.getContext()),
DenseIntElementsAttr::get(inputABIRanksType,
llvm::makeArrayRef(inputABIRanks))));
auto inputShapesFlattened = flattenABIShapes(inputABIShapes);
namedAttrs.push_back(std::make_pair(
Identifier::get("inputShapes", func.getContext()),
DenseIntElementsAttr::get(
inputABIShapesType,
llvm::makeArrayRef(flattenABIShapes(inputABIShapes)))));
}
if (outputABIArgTypes.size()) {
      // Only add output information if there are outputs
namedAttrs.push_back(std::make_pair(
Identifier::get("outputArgTypes", func.getContext()),
DenseIntElementsAttr::get(outputABIDataType,
llvm::makeArrayRef(outputABIArgTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputElementTypes", func.getContext()),
DenseIntElementsAttr::get(
outputABIElementType,
llvm::makeArrayRef(outputABIElementTypes))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputRanks", func.getContext()),
DenseIntElementsAttr::get(outputABIRanksType,
llvm::makeArrayRef(outputABIRanks))));
namedAttrs.push_back(std::make_pair(
Identifier::get("outputShapes", func.getContext()),
DenseIntElementsAttr::get(
outputABIShapesType,
llvm::makeArrayRef(flattenABIShapes(outputABIShapes)))));
}
builder.create<refbackrt::FuncMetadataOp>(func.getLoc(), ArrayRef<Type>{},
ArrayRef<Value>{}, namedAttrs);
if (!expressibleWithRefbackrtABI(func.getType()))
return func.emitError() << "func not expressible with refbackrt ABI";
}
return success();
}
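// Schematic example (illustrative; the exact printed form of the ops may
// differ): for a public `func @forward(%arg0: memref<2x3xf32>) -> f32`, this
// emits roughly
//
//   refbackrt.module_metadata {
//     refbackrt.func_metadata {funcName = @forward, numInputs = 1 : i32,
//                              numOutputs = 1 : i32, inputArgTypes = ...,
//                              inputElementTypes = ..., inputRanks = ...,
//                              inputShapes = ..., ...}
//   }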
//===----------------------------------------------------------------------===//
// Dialect conversion.
//===----------------------------------------------------------------------===//
namespace {
class LowerAssertOp : public OpConversionPattern<AssertOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
AssertOp::Adaptor adaptor(operands);
// The refbackrt runtime function aborts if the argument is true, rather
// than when it is false as an `assert` does. So negate the predicate (by
// xor'ing with 1).
auto c1 = rewriter.create<ConstantOp>(
op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(),
APInt(/*numBits=*/1, /*val=*/1)));
Value assertFailed = rewriter.create<XOrOp>(op.getLoc(), adaptor.arg(), c1);
rewriter.replaceOpWithNewOp<refbackrt::AbortIfOp>(op, assertFailed,
op.msgAttr());
return success();
}
};
} // namespace
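// Illustrative rewrite (schematic IR): an `assert %pred, "msg"` becomes
//
//   %true = constant true
//   %failed = xor %pred, %true
//   refbackrt.abort_if %failed, "msg"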
namespace {
// At ABI boundaries, convert all memrefs to unranked memrefs so that they have
// a fixed ABI.
class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FunctionType type = op.getType();
TypeConverter::SignatureConversion entryConversion(type.getNumInputs());
if (failed(typeConverter->convertSignatureArgs(type.getInputs(),
entryConversion)))
return rewriter.notifyMatchFailure(op, "could not convert inputs");
SmallVector<Type, 1> newResultTypes;
if (failed(typeConverter->convertTypes(type.getResults(), newResultTypes)))
return rewriter.notifyMatchFailure(op, "could not convert outputs");
rewriter.updateRootInPlace(op, [&] {
// Update the function type.
op.setType(FunctionType::get(op.getContext(),
entryConversion.getConvertedTypes(),
newResultTypes));
// Rewrite the entry block.
Block &oldEntry = op.getBody().front();
Block &newEntry =
*rewriter.applySignatureConversion(&op.getBody(), entryConversion);
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&newEntry);
BlockArgument newArg, oldArg;
for (auto newAndOldArg :
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
std::tie(newArg, oldArg) = newAndOldArg;
auto memref = rewriter.create<memref::CastOp>(op.getLoc(), newArg,
oldArg.getType());
rewriter.replaceUsesOfBlockArgument(oldArg, memref);
}
});
return success();
}
};
} // namespace
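// Illustrative rewrite (schematic IR): a public function
//
//   func @f(%arg0: memref<2x3xf32>) { ... }
//
// becomes
//
//   func @f(%arg0: memref<*xf32>) {
//     %0 = memref.cast %arg0 : memref<*xf32> to memref<2x3xf32>
//     // ... original body uses %0 ...
//   }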
namespace {
// At the return ABI boundaries, convert to the ABI type.
// This pattern is needed to trigger the type conversion mechanics to do a
// target materialization.
class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
return success();
}
};
} // namespace
static LogicalResult doDialectConversion(ModuleOp module) {
auto *context = module.getContext();
TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
[](MemRefType type) { return getABIMemrefType(type); });
typeConverter.addTargetMaterialization(
[](OpBuilder &builder, UnrankedMemRefType type, ValueRange inputs,
Location loc) -> Value {
assert(inputs.size() == 1);
return builder.create<memref::CastOp>(
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
});
RewritePatternSet patterns(context);
ConversionTarget target(*context);
target.addLegalDialect<refbackrt::RefbackrtDialect>();
target.addLegalDialect<StandardOpsDialect>();
target.addLegalDialect<memref::MemRefDialect>();
patterns.add<FuncOpSignatureConversion>(typeConverter, context);
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
patterns.add<RewriteReturnOp>(typeConverter, context);
target.addDynamicallyLegalOp<ReturnOp>(
[&](ReturnOp op) { return typeConverter.isLegal(op); });
patterns.add<LowerAssertOp>(context);
target.addIllegalOp<AssertOp>();
return applyPartialConversion(module, target, std::move(patterns));
}
namespace {
// This pass lowers the public ABI of the module to the primitives exposed by
// the refbackrt dialect.
class LowerToRefbackrtABI
: public LowerToRefbackrtABIBase<LowerToRefbackrtABI> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<refbackrt::RefbackrtDialect, memref::MemRefDialect>();
}
void runOnOperation() override {
ModuleOp module = getOperation();
// Before we lower anything, capture any needed metadata about the argument
// lists that will be needed for safely invoking the raw runtime functions
// later. (for example, number of expected arguments/results, types,
// etc.)
if (failed(createModuleMetadata(module)))
return signalPassFailure();
// Now do the actual conversion / lowering.
if (failed(doDialectConversion(module)))
return signalPassFailure();
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createLowerToRefbackrtABIPass() {
return std::make_unique<LowerToRefbackrtABI>();
}

View File

@ -1,25 +0,0 @@
//===- PassDetail.h - RefBackend Pass class details -------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef REFBACKEND_PASSDETAIL_H
#define REFBACKEND_PASSDETAIL_H
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace NPCOMP {
#define GEN_PASS_CLASSES
#include "npcomp/RefBackend/Passes.h.inc"
} // namespace NPCOMP
} // end namespace mlir
#endif // REFBACKEND_PASSDETAIL_H

View File

@ -1,234 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the base file for npcomp's "reference backend".
//
// The input to this backend is a layer that consists of linalg-on-tensors
// together with std scalar ops and control flow.
//
// The output of this backend is LLVM IR suitable for JITing.
//
// We expect that other backends will appear that have a similar kind of
// interface. IREE already uses this layering.
//
//===----------------------------------------------------------------------===//
#include "npcomp/RefBackend/RefBackend.h"
#include "PassDetail.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "npcomp/Backend/Common/Passes.h"
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
// Pass registration
//===----------------------------------------------------------------------===//
namespace {
#define GEN_PASS_REGISTRATION
#include "npcomp/RefBackend/Passes.h.inc"
} // end namespace
void mlir::NPCOMP::registerRefBackendPasses() {
::registerPasses();
mlir::PassPipelineRegistration<RefBackendLoweringPipelineOptions>(
"refback-lowering-pipeline", "RefBackend lowering pipeline.",
mlir::NPCOMP::createRefBackendLoweringPipeline);
}
//===----------------------------------------------------------------------===//
// LowerAllocMemRefOps
//===----------------------------------------------------------------------===//
namespace {
class LowerAllocMemRefOp : public OpRewritePattern<refback::AllocMemRefOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(refback::AllocMemRefOp op,
PatternRewriter &rewriter) const override {
auto memrefType = op.getType().cast<MemRefType>();
auto shape = op.getOperand();
// std.alloc only accepts the dynamic extents as operands, so only
// collect those.
SmallVector<Value, 6> dynamicExtents;
for (int i = 0, e = memrefType.getRank(); i < e; i++) {
if (memrefType.isDynamicDim(i)) {
auto ci = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
auto extent = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape,
ValueRange({ci}));
dynamicExtents.push_back(extent);
}
}
rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
dynamicExtents);
return success();
}
};
} // namespace
namespace {
class LowerAllocMemRefOps
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<LowerAllocMemRefOp>(context);
ConversionTarget target(*context);
target.addIllegalOp<refback::AllocMemRefOp>();
target.addLegalOp<tensor::ExtractOp>();
target.addLegalOp<memref::AllocOp>();
target.addLegalOp<ConstantOp>();
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createLowerAllocMemRefOpsPass() {
return std::make_unique<LowerAllocMemRefOps>();
}
//===----------------------------------------------------------------------===//
// createRefBackendLoweringPipeline
//===----------------------------------------------------------------------===//
void mlir::NPCOMP::createRefBackendLoweringPipeline(
OpPassManager &pm, const RefBackendLoweringPipelineOptions &options) {
// Delete dead lists. RefBackend doesn't support lists, but in some cases
// we can get by if all the lists are dead.
pm.addNestedPass<FuncOp>(
NPCOMP::CommonBackend::createDeleteDeadIREEListsPass());
// Convert all elementwise ops to linalg.
//
  // For correctness, this lets us reuse the linalg bufferization, which
  // applies uniformly to all linalg structured ops.
  //
  // Also, converting to linalg here opens up a lot of optimization
  // opportunities.
pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
if (options.optimize) {
pm.addNestedPass<FuncOp>(createLinalgElementwiseOpFusionPass());
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
// Lower shape constraints before we enter tensor->memref conversion.
// That is, we expand shape.cstr_* ops to eager error handling code.
pm.addNestedPass<FuncOp>(createConvertShapeConstraintsPass());
// Run shape canonicalizations. In particular, this erases shape.assuming,
// now that we have converted shape constraints.
// TODO: Don't canonicalize everything.
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
// Lower shape ops to std.
pm.addPass(createConvertShapeToStandardPass());
// --------------------------------------------------------------------------
// Lower the `tensor` type to `memref`.
// --------------------------------------------------------------------------
// We make a conscious effort here to do this as a sequence of separate passes
// rather than a single mega dialect conversion pass.
//
// This means that intermediate steps have source/target materializations
// (memref.tensor_load / memref.buffer_cast) in the IR.
// Run tensor constant bufferization.
// This pass has to run on a module op, and so does the final
// FuncBufferizePass. But everything else can run in parallel on functions,
// so we try to bracket the entire bufferization pipeline with the module
// passes to allow maximum parallelism.
pm.addPass(createTensorConstantBufferizePass());
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
// We need to resolve this to std.alloc which takes individual extents.
pm.addNestedPass<FuncOp>(createLowerAllocMemRefOpsPass());
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
pm.addNestedPass<FuncOp>(createStdBufferizePass());
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
pm.addPass(createFuncBufferizePass());
pm.addNestedPass<FuncOp>(createFinalizingBufferizePass());
// TODO: Do buffer deallocation. We should be able to just drop in the
// upstream pass?
// At this point, we have lots of loose stuff floating around from lowering,
// so it's a good time to do some general cleanups.
if (options.optimize) {
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
// --------------------------------------------------------------------------
// Preparation for converting to an LLVM module.
// --------------------------------------------------------------------------
// Now, we begin the process of lowering to LLVM's level of abstraction
// (after which LLVM will take over lowering to machine code).
// Lower linalg ops to loops.
// TODO: Do some linalg optimizations like tiling here.
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
  // Run some cleanups.
if (options.optimize) {
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
// --------------------------------------------------------------------------
// Final conversion to an LLVM module.
// --------------------------------------------------------------------------
// Convert affine to std control flow in preparation for going to LLVM.
pm.addNestedPass<FuncOp>(createLowerAffinePass());
// Convert scf to std control flow in preparation for going to LLVM.
pm.addNestedPass<FuncOp>(createLowerToCFGPass());
// Convert functions signatures and other constructs that interface with the
// runtime to the `refbackrt` dialect.
pm.addPass(createLowerToRefbackrtABIPass());
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
// which reuses the upstream patterns and gives us a place to add our own
// patterns for our own custom ops like the refbackrt ops.
pm.addPass(createLowerToLLVMPass());
// Although LLVM will clean everything up eventually, for the sake of IR
// clarity while still in MLIR, run some cleanups.
if (options.optimize) {
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
pm.addNestedPass<FuncOp>(createCSEPass());
}
}
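// Example invocation (illustrative), using the "refback-lowering-pipeline"
// registration above; the option name is assumed from
// RefBackendLoweringPipelineOptions:
//
//   npcomp-opt -refback-lowering-pipeline='optimize=true' input.mlir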

View File

@ -1,36 +0,0 @@
# Appease LLVM check that all sources are covered by a target.
# This doesn't seem to play well with having multiple targets
# in a single directory.
set(LLVM_OPTIONAL_SOURCES
Runtime.cpp
CompilerRuntime.cpp
)
# The library that users link against, defining basic interactions with a
# refbackrt module and the relevant data structures.
add_npcomp_library(NPCOMPRuntime
Runtime.cpp
)
mlir_check_all_link_libraries(NPCOMPRuntime)
# The library that defines the symbols that the compiler emits references
# to.
# Note: it uses some of the same facilities that the user API depends on;
# we use a linker script to ensure that the shared library only exposes the
# symbols the compiler needs.
#
# This is currently done as a shared library to make it suitable for being
# loaded by mlir::ExecutionEngine. In e.g. an embedded scenario, we would
# need to create a static library and link that into the binary.
add_npcomp_library(NPCOMPCompilerRuntimeShlib
SHARED
CompilerRuntime.cpp
EXCLUDE_FROM_LIBNPCOMP
)
target_link_libraries(NPCOMPCompilerRuntimeShlib PRIVATE NPCOMPRuntime)
if (UNIX AND NOT APPLE)
set_target_properties(NPCOMPCompilerRuntimeShlib PROPERTIES LINK_FLAGS
"-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/unix_version.script")
endif()
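# Illustrative sketch (the real unix_version.script is not part of this
# diff): a GNU ld version script that exports only the compiler-runtime
# symbols might look like
#
#   {
#     global: __npcomp_compiler_rt_*;
#     local: *;
#   };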

View File

@ -1,87 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains data structures (which we typically call "descriptors")
// that are emitted by the compiler and must be kept in sync with the compiler
// code that creates them in LowerToLLVM.cpp.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
#define NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
#include <cstdint>
namespace refbackrt {
// All arguments are packed into this type-erased form for being invoked. See
// LowerToLLVM.cpp for more details.
typedef void ABIFunc(void **, void **);
enum class ABIArgType : std::uint32_t {
kNone = 0,
kMemref,
kF32,
kF64,
};
enum class ABIElementType : std::uint32_t {
kNone = 0,
kF32,
};
struct InputDescriptor {
ABIArgType abiType;
ABIElementType elementType;
std::int32_t rank;
std::int32_t* extents;
// TODO(brycearden): Change to bool at ABI boundary
// std::int32_t isStatic;
};
struct OutputDescriptor {
ABIArgType abiType;
ABIElementType elementType;
std::int32_t rank;
std::int32_t* extents;
// TODO(brycearden): Change to bool at ABI boundary
//std::int32_t isStatic;
};
struct FuncDescriptor {
// The length of the function name.
std::int32_t nameLen;
// The name of the function, to allow lookup.
const char *name;
// This is a raw function pointer to the function's entry point as
// emitted by the compiler.
ABIFunc *functionPtr;
// The number of inputs to the function.
std::int32_t numInputs;
// The number of outputs of the function.
std::int32_t numOutputs;
// TODO: Add shape checking to arg / result descriptor(s)
InputDescriptor *inputDescriptors;
OutputDescriptor *outputDescriptors;
};
// The top-level entry point of the module metadata emitted by the
// compiler. Unlike all the other descriptors here, external code does handle
// this type (albeit through an opaque pointer).
struct ModuleDescriptor {
std::int32_t numFuncDescriptors;
FuncDescriptor *functionDescriptors;
};
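// Illustrative helper (not part of the original header): how a runtime can
// locate a FuncDescriptor by name. Note that `name` is not guaranteed to be
// nul-terminated, so the comparison is bounded by `nameLen`.
inline FuncDescriptor *findFuncDescriptor(ModuleDescriptor *module,
                                          const char *name,
                                          std::int32_t nameLen) {
  for (std::int32_t i = 0; i < module->numFuncDescriptors; i++) {
    FuncDescriptor &fd = module->functionDescriptors[i];
    if (fd.nameLen != nameLen)
      continue;
    bool equal = true;
    for (std::int32_t j = 0; j < nameLen; j++)
      equal = equal && fd.name[j] == name[j];
    if (equal)
      return &fd;
  }
  return nullptr;
}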
} // namespace refbackrt
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H

View File

@ -1,33 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Symbols referenced only by the compiler and which will be compiled into a
// shared object that a JIT can load to provide those symbols.
//
//===----------------------------------------------------------------------===//
#include <cstdio>
#include <cstdlib>
#include "CompilerDataStructures.h"
#include "npcomp/RefBackend/Runtime/UserAPI.h"
using namespace refbackrt;
extern "C" {
__attribute__((visibility("default"))) void
__npcomp_compiler_rt_abort_if(bool b, const char *msg);
}
void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
if (b) {
std::fprintf(stderr, "NPCOMP: aborting: %s\n", msg);
std::exit(1);
}
}

View File

@ -1,519 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/RefBackend/Runtime/UserAPI.h"
#include "llvm/Support/ErrorHandling.h"
#include <array>
#include <cassert>
#include <cstdint>
#include <cstring>
#include "CompilerDataStructures.h"
using namespace refbackrt;
//===----------------------------------------------------------------------===//
// Memref descriptors for interacting with MLIR codegenerated code.
//===----------------------------------------------------------------------===//
namespace {
// These definitions are based on the ones in
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
// sync.
//
// Those definitions are flawed though because they are overly templated.
struct MemrefDescriptor {
void *allocatedPtr;
void *dataPtr;
std::int64_t offset;
// Tail-allocated int64_t sizes followed by strides.
MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
return MutableArrayRef<std::int64_t>(tail, assumedRank);
}
MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
}
// Returns a malloc-allocated MemrefDescriptor with the specified extents and
// default striding.
static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
// Returns the number of elements in this MemrefDescriptor, assuming this
// descriptor has rank `assumedRank`.
std::int32_t getNumElements(int assumedRank) {
if (assumedRank == 0)
return 1;
return getSizes(assumedRank)[0] * getStrides(assumedRank)[0];
}
};
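// Illustrative layout (a sketch, not anything emitted verbatim): a rank-2
// descriptor for a dense 3x4 f32 buffer is
//   { allocatedPtr, dataPtr, offset = 0 } followed in memory by
//   sizes   = [3, 4]   // tail-allocated int64_t's
//   strides = [4, 1]
// so getNumElements(2) == sizes[0] * strides[0] == 3 * 4 == 12, which relies
// on the dense, contiguous striding this backend always produces.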
} // namespace
namespace {
struct UnrankedMemref {
int64_t rank;
MemrefDescriptor *descriptor;
};
} // namespace
MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
void *data) {
auto rank = extents.size();
auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
auto *descriptor = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
descriptor->allocatedPtr = data;
descriptor->dataPtr = data;
descriptor->offset = 0;
// Iterate in reverse, copying the dimension sizes (i.e. extents) and
// calculating the strides for a standard dense layout.
std::int64_t stride = 1;
for (int i = 0, e = rank; i < e; i++) {
auto revIdx = e - i - 1;
descriptor->getSizes(rank)[revIdx] = extents[revIdx];
descriptor->getStrides(rank)[revIdx] = stride;
stride *= extents[revIdx];
}
return descriptor;
}
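// For example (illustrative): create({2, 3, 4}, data) produces
// sizes = [2, 3, 4] and strides = [12, 4, 1], the standard dense row-major
// layout computed back-to-front by the loop above.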
static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) {
auto byteSize = tensor->getDataByteSize();
void *data = std::malloc(byteSize);
std::memcpy(data, tensor->getData(), byteSize);
auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data);
return UnrankedMemref{tensor->getRank(), descriptor};
}
static Tensor *convertUnrankedMemrefToRefbackrtTensor(
std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) {
// Launder from std::int64_t to std::int32_t.
auto extents64 = descriptor->getSizes(rank);
  constexpr int kMaxRank = 20;
  assert(rank <= kMaxRank && "tensor rank exceeds kMaxRank");
  std::array<std::int32_t, kMaxRank> extents32Buf;
for (int i = 0, e = extents64.size(); i < e; i++)
extents32Buf[i] = extents64[i];
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
elementType, descriptor->dataPtr);
}
//===----------------------------------------------------------------------===//
// Tensor
//===----------------------------------------------------------------------===//
static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
std::int32_t ret = 1;
for (int i = 0, e = extents.size(); i < e; i++) {
ret *= extents[i];
}
return ret;
}
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
switch (type) {
case ElementType::NONE:
return 0;
case ElementType::F32:
return 4;
}
llvm_unreachable("unsupported dtype");
}
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
switch (type) {
case ElementType::NONE:
return "NONE";
case ElementType::F32:
return "F32";
}
llvm_unreachable("unsupported element type string");
}
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
switch (type) {
case ArgType::kNone:
return "kNone";
case ArgType::kTensor:
return "kTensor";
case ArgType::kF32:
return "kF32";
case ArgType::kF64:
return "kF64";
}
llvm_unreachable("unsupported arg type string");
}
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
void *data) {
return Ref<Tensor>(createRaw(extents, type, data));
}
Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
void *data) {
auto *tensor = static_cast<Tensor *>(
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
tensor->refCount = 0;
tensor->elementType = type;
tensor->rank = extents.size();
auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
// TODO: Align the buffer.
tensor->allocatedPtr = std::malloc(byteSize);
tensor->data = tensor->allocatedPtr;
std::memcpy(tensor->data, data, byteSize);
for (int i = 0, e = extents.size(); i < e; i++)
tensor->getMutableExtents()[i] = extents[i];
return tensor;
}
std::int32_t Tensor::getDataByteSize() const {
return getElementTypeByteSize(getElementType()) * totalElements(getExtents());
}
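// Example usage (a sketch; Tensor::create copies the data, so the caller
// retains ownership of its own buffer):
//   float data[4] = {1.0f, 2.0f, 3.0f, 4.0f};
//   std::int32_t extents[2] = {2, 2};
//   Ref<Tensor> t = Tensor::create(
//       ArrayRef<std::int32_t>(extents, 2), ElementType::F32, data);
//   // t->getDataByteSize() == 4 * sizeof(float) == 16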
//===----------------------------------------------------------------------===//
// Module metadata descriptors.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
// Module operations.
//===----------------------------------------------------------------------===//
template <typename T> static void *ToVoidPtr(T *ptr) {
return const_cast<void *>(static_cast<const void *>(ptr));
}
static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
StringRef name) {
for (int i = 0, e = moduleDescriptor->numFuncDescriptors; i < e; i++) {
auto &functionDescriptor = moduleDescriptor->functionDescriptors[i];
if (StringRef(functionDescriptor.name, functionDescriptor.nameLen) ==
name) {
return &functionDescriptor;
}
}
return nullptr;
}
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
StringRef functionName, ArrayRef<RtValue> inputs,
MutableArrayRef<RtValue> outputs) {
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
assert(descriptor && "unknown function name");
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
// We haven't committed to using "vector" in this runtime code, so use
// a fixed-sized array.
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
std::array<void *, kMaxArity * 2> packedInputs;
std::array<void *, kMaxArity> packedOutputs;
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
// TODO: Avoid the deep copy. It makes the later lifetime management code
// more complex though (and maybe impossible given the current abstractions).
//
// Create a type-erased list of "packed inputs" to pass to the
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
// one LLVM/C ABI argument to the underlying function.
//
// The ABI lowering on StandardToLLVM conversion side will
// "explode" the unranked memref descriptors on the underlying function
// into separate arguments for the rank and pointer-to-descriptor.
for (int i = 0, e = inputs.size(); i < e; i++) {
auto idx = 2 * i;
if (inputs[i].isTensor()) {
inputUnrankedMemrefs[i] =
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
packedInputs[idx] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
packedInputs[idx + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
} else if (inputs[i].isScalar()) {
packedInputs[idx] = ToVoidPtr(&inputs[i]);
} else {
assert(false && "unsupported input RtValue type");
}
}
// Create a type-erased list of "packed output" to pass to the
// LLVM/C ABI wrapper function.
//
// Due to how StandardToLLVM lowering works, each packedOutput pointer
// corresponds to a single UnrankedMemref (not "exploded").
for (int i = 0, e = outputs.size(); i < e; i++) {
if (outputs[i].isTensor()) {
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
} else if (outputs[i].isScalar()) {
packedOutputs[i] = ToVoidPtr(&outputs[i]);
}
}
// Actually invoke the function!
descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
// Copy out the result data into refbackrt::Tensor's.
// TODO: Avoid needing to make a deep copy.
for (int i = 0, e = outputs.size(); i < e; i++) {
// TODO: Have compiler emit the element type in the metadata.
if (outputs[i].isTensor()) {
auto elementType = ElementType::F32;
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
elementType);
outputs[i] = RtValue(Ref<Tensor>(tensor));
} else if (outputs[i].isFloat()) {
outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
}
}
// Now, we just need to free all the UnrankedMemref's that we created.
// This is complicated by the fact that multiple input/output UnrankedMemref's
// can end up with the same backing buffer (`allocatedPtr`), and we need
// to avoid double-freeing.
// Output buffers might alias any other input or output buffer.
// Input buffers are guaranteed to not alias each other.
// Free the output buffers.
for (int i = 0, e = outputs.size(); i < e; i++) {
if (outputs[i].isRef()) {
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
// Multiple returned memrefs can point into the same underlying
// malloc allocation. Do a linear scan to see if any of the previously
// deallocated buffers already freed this pointer.
bool bufferNeedsFreeing = true;
for (int j = 0; j < i; j++) {
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
      if (bufferNeedsFreeing)
        std::free(allocatedPtr);
}
}
// Free the input buffers.
for (int i = 0, e = inputs.size(); i < e; i++) {
if (!inputs[i].isRef())
continue;
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
bool bufferNeedsFreeing = true;
for (int j = 0, je = outputs.size(); j < je; j++) {
if (!outputs[j].isRef())
continue;
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
bufferNeedsFreeing = false;
}
// HACK: The returned memref can point into statically allocated memory that
// we can't pass to `free`, such as the result of lowering a tensor-valued
// `std.constant` to `std.global_memref`. The LLVM lowering of
// std.global_memref sets the allocated pointer to the magic value
// 0xDEADBEEF, which we sniff for here. This is yet another strong signal
// that memref is really not the right abstraction for ABI's.
if (reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF)
bufferNeedsFreeing = false;
    if (bufferNeedsFreeing)
      std::free(allocatedPtr);
}
// Free the output descriptors.
for (int i = 0, e = outputs.size(); i < e; i++) {
if (!outputs[i].isRef())
continue;
// The LLVM lowering guarantees that each returned unranked memref
// descriptor is separately malloc'ed, so no need to do anything special
// like we had to do for the allocatedPtr's.
std::free(outputUnrankedMemrefs[i].descriptor);
}
// Free the input descriptors.
for (int i = 0, e = inputs.size(); i < e; i++) {
if (!inputs[i].isRef())
continue;
std::free(inputUnrankedMemrefs[i].descriptor);
}
}
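// Sketch of a caller (hypothetical; refback-run and the Python bindings are
// the real clients of this entry point):
//   FunctionMetadata meta;
//   if (failed(getMetadata(moduleDescriptor, "forward", meta)))
//     ...; // unknown function
//   std::array<RtValue, kMaxArity> in, out;
//   in[0] = RtValue(inputTensor);
//   invoke(moduleDescriptor, "forward",
//          ArrayRef<RtValue>(in.data(), meta.numInputs),
//          MutableArrayRef<RtValue>(out.data(), meta.numOutputs));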
static InputArgInfo
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
InputArgInfo ret;
// Set arg / element types accordingly
switch (inputDescriptor.abiType) {
case ABIArgType::kNone:
ret.argType = ArgType::kNone;
ret.elementType = ElementType::NONE;
break;
case ABIArgType::kMemref:
ret.argType = ArgType::kTensor;
ret.elementType = ElementType::F32;
break;
case ABIArgType::kF32:
ret.argType = ArgType::kF32;
ret.elementType = ElementType::NONE;
break;
case ABIArgType::kF64:
ret.argType = ArgType::kF64;
ret.elementType = ElementType::NONE;
break;
}
// Extract shape information
ret.rank = inputDescriptor.rank;
for (int i = 0; i < inputDescriptor.rank; i++) {
ret.extents[i] = inputDescriptor.extents[i];
}
return ret;
}
static OutputArgInfo
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
OutputArgInfo ret;
// Set arg / element types accordingly
switch (outputDescriptor.abiType) {
case ABIArgType::kNone:
ret.argType = ArgType::kNone;
ret.elementType = ElementType::NONE;
break;
case ABIArgType::kMemref:
ret.argType = ArgType::kTensor;
ret.elementType = ElementType::F32;
break;
case ABIArgType::kF32:
ret.argType = ArgType::kF32;
ret.elementType = ElementType::NONE;
break;
case ABIArgType::kF64:
ret.argType = ArgType::kF64;
ret.elementType = ElementType::NONE;
break;
}
// Extract shape information
ret.rank = outputDescriptor.rank;
for (int i = 0; i < outputDescriptor.rank; i++) {
ret.extents[i] = outputDescriptor.extents[i];
}
return ret;
}
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
StringRef functionName,
FunctionMetadata &outMetadata) {
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
if (!descriptor)
return failure();
outMetadata.numInputs = descriptor->numInputs;
outMetadata.numOutputs = descriptor->numOutputs;
for (int i = 0; i < descriptor->numInputs; i++) {
outMetadata.inputArgInfos[i] =
getExternalInputArgInfo(descriptor->inputDescriptors[i]);
}
for (int i = 0; i < descriptor->numOutputs; i++) {
outMetadata.outputArgInfos[i] =
getExternalOutputArgInfo(descriptor->outputDescriptors[i]);
}
return success();
}
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
const InputArgInfo &info) {
if (value.isTensor()) {
auto refTensor = value.toTensor();
// Don't bother checking shapes for unranked tensors
if (info.rank < 0)
return success();
if (refTensor->getRank() != info.rank)
return failure();
auto tensorExtents = refTensor->getExtents();
for (int i = 0; i < info.rank; i++) {
// If a dimension is dynamic, it is encoded as extent = -1
// and we should skip checking over that dimension
if (info.extents[i] > 0 && (info.extents[i] != tensorExtents[i]))
return failure();
}
}
return success();
}
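// For example (illustrative): a function compiled for tensor<2x?xf32> has
// info.rank == 2 and info.extents == [2, -1]; a 2x7 input passes (the dynamic
// dimension is skipped) while a 3x7 input fails on dimension 0.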
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
const InputArgInfo &info) {
// Generic checks based on argType(s)
if ((value.isTensor() && info.argType != ArgType::kTensor) ||
(value.isFloat() && info.argType != ArgType::kF32))
return failure();
if (value.isRef()) {
// Will need special error checking for ref-counted types
// Currently only f32 tensors are supported
if (value.isTensor()) {
auto refTensor = value.toTensor();
if (refTensor->getElementType() != ElementType::F32)
return failure();
} else {
assert(false && "Unsupported input type checking for Ref type");
}
}
return success();
}
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
constexpr int32_t kDynamicConstantShape = 100;
switch (info.argType) {
case ArgType::kTensor: {
// HACK: for dynamic dims the shape will be negative, so for now we are
// just going to create a tensor of size kDynamicConstantShape
std::array<int32_t, kMaxRank> tensorShape;
for (int i = 0; i < info.rank; i++) {
tensorShape[i] =
info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
}
refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);
int numel = 1;
for (int i = 0; i < info.rank; i++)
numel *= shape[i];
    void *data;
    switch (info.elementType) {
    case ElementType::F32: {
      auto byteSize = numel * sizeof(float);
      data = static_cast<void *>(std::malloc(byteSize));
      assert(data && "could not allocate tensor");
      std::memset(data, 0, byteSize);
      break;
    }
    default: {
      assert(false && "unknown output tensor type");
      return RtValue();
    }
    }
    // Tensor::create mallocs and memcpys the data into the Tensor object,
    // so we must free our temporary buffer after creating the tensor.
    auto refTensor = Tensor::create(shape, ElementType::F32, data);
    std::free(data);
    return RtValue(refTensor);
}
case ArgType::kF32: {
return RtValue(-20.0f);
}
default: {
assert(false && "Don't know how to handle this artType");
return RtValue();
}
}
}


@ -1,5 +0,0 @@
{
global: __npcomp_compiler_rt_*;
local: *;
};


@ -3,25 +3,6 @@ include(MLIRDetectPythonEnv)
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=npcomp.")
################################################################################
# Resources that must be packaged into the python tree
################################################################################
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/backend/refjit_resources")
add_custom_target(NPCOMPPythonResources ALL)
add_custom_command(
TARGET NPCOMPPythonResources
COMMAND ${CMAKE_COMMAND} -E copy
# TODO: Make the runtime library work for windows.
# TODO: Use $<TARGET-FILE:> for this.
${CMAKE_BINARY_DIR}/lib/libNPCOMPCompilerRuntimeShlib${CMAKE_SHARED_LIBRARY_SUFFIX}
${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core/npcomp/compiler/generic/backend/libNPCOMPCompilerRuntimeShlib${CMAKE_SHARED_LIBRARY_SUFFIX}
)
add_dependencies(NPCOMPPythonResources
NPCOMPCompilerRuntimeShlib
)
################################################################################
# Declare sources
################################################################################
@ -50,17 +31,11 @@ declare_mlir_python_sources(NPCOMPPythonCAPIHeaderSources
# Extensions
################################################################################
set(_addl_extension_sources)
if(NPCOMP_ENABLE_REFJIT)
list(APPEND _addl_extension_sources "${CMAKE_CURRENT_SOURCE_DIR}/RefJITBackend.cpp")
endif()
declare_mlir_python_extension(NPCOMPPythonExtensions.Core
MODULE_NAME _npcomp
ADD_TO_PARENT NPCOMPPythonExtensions
SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/NpcompModule.cpp
${_addl_extension_sources}
EMBED_CAPI_LINK_LIBS
NPCOMPCAPI
PRIVATE_LINK_LIBS
@ -109,9 +84,6 @@ add_mlir_python_modules(NPCOMPPythonModules
COMMON_CAPI_LINK_LIBS
NPCOMPPythonCAPI
)
add_dependencies(NPCOMPPythonModules
NPCOMPPythonResources
)
################################################################################
# Torch support libraries.


@ -38,10 +38,4 @@ PYBIND11_MODULE(_npcomp, m) {
// Optional backend modules.
auto backend_m = m.def_submodule("backend", "Backend support");
(void)backend_m;
#ifdef NPCOMP_ENABLE_REFJIT
auto refjit_m =
backend_m.def_submodule("refjit", "Reference CPU Jit Backend");
::npcomp::python::defineBackendRefJitModule(refjit_m);
#endif
}


@ -1,164 +0,0 @@
//===- PythonModule.cpp - RefJIT python bindings --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "./NpcompModule.h"
#include <cstdlib>
#include "pybind11/numpy.h"
#include "npcomp-c/RefJITBackend.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
using llvm::SmallVector;
using llvm::StringRef;
using llvm::Twine;
static NpcompRefJitElementType
mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
if (format == "f")
return NPCOMP_REFJIT_F32;
std::string message("unsupported buffer format: ");
message.append(format);
throw py::raiseValueError(message);
}
namespace {
struct PyRefJitModule {
PyRefJitModule(NpcompRefJitModule instance) : instance(instance) {}
~PyRefJitModule() {
if (instance.ptr)
npcompRefJitModuleDestroy(instance);
}
PyRefJitModule(const PyRefJitModule &) = delete;
void operator=(const PyRefJitModule &) = delete;
PyRefJitModule(PyRefJitModule &&other) : instance(other.instance) {
other.instance.ptr = nullptr;
}
NpcompRefJitModule instance = {nullptr};
};
struct PyRefJitValueList {
PyRefJitValueList(NpcompRefJitValueList instance) : instance(instance) {}
~PyRefJitValueList() {
if (instance.ptr)
npcompRefJitValueListDestroy(instance);
}
PyRefJitValueList(const PyRefJitValueList &) = delete;
void operator=(const PyRefJitValueList &) = delete;
PyRefJitValueList(PyRefJitValueList &&other) : instance(other.instance) {
other.instance.ptr = nullptr;
}
NpcompRefJitValueList instance = {nullptr};
};
} // namespace
void npcomp::python::defineBackendRefJitModule(py::module &m) {
m.def("build_backend_compilation_pipeline", [](MlirPassManager capiPm) {
npcompRefJitBuildBackendCompilationPipeline(capiPm, /*optimize=*/true);
});
py::class_<PyRefJitValueList>(m, "ValueList");
py::class_<PyRefJitModule>(m, "JITModule")
.def_static(
"from_compiled_module",
[](MlirModule capiModule,
std::vector<std::string> pySharedLibs) -> PyRefJitModule {
SmallVector<MlirStringRef, 4> sharedLibs;
for (auto &s : pySharedLibs)
sharedLibs.push_back(MlirStringRef{s.data(), s.size()});
char *errorMessageCstr;
            NpcompRefJitModule m = npcompRefJitModuleCreate(
                capiModule, sharedLibs.data(), sharedLibs.size(),
                &errorMessageCstr);
if (npcompRefJitModuleIsNull(m)) {
std::string errorMessage(errorMessageCstr);
std::free(errorMessageCstr);
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
}
return PyRefJitModule(m);
},
py::arg("module"), py::arg("shared_libs"))
.def(
"invoke",
[](PyRefJitModule &self, std::string functionName,
std::vector<py::buffer> inputs) {
py::object ioListObject =
py::cast(PyRefJitValueList(npcompRefJitValueListCreate()));
PyRefJitValueList &ioList =
py::cast<PyRefJitValueList &>(ioListObject);
// Prepare inputs.
for (auto &buffer : inputs) {
// Request a C contiguous view as that is what Tensor accepts now
            // (no strides or non-row-major layout).
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
std::unique_ptr<Py_buffer> view(new Py_buffer());
if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
throw py::error_already_set();
}
py::buffer_info info(view.release());
auto elementType =
mapBufferFormatToElementType(info.format, info.itemsize);
SmallVector<int32_t, 4> extents(info.shape.begin(),
info.shape.end());
npcompRefJitValueAddTensorCopy(ioList.instance, elementType,
extents.data(), extents.size(),
info.ptr);
}
// Invoke.
char *errorMessageCstr;
if (!npcompRefJitModuleInvoke(
self.instance,
MlirStringRef{functionName.data(), functionName.size()},
ioList.instance, &errorMessageCstr)) {
std::string errorMessage(errorMessageCstr);
std::free(errorMessageCstr);
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
}
// Prepare outputs.
std::vector<py::object> outputs;
for (intptr_t i = 0; i < npcompRefJitValueListSize(ioList.instance);
++i) {
if (npcompRefJitValueIsaTensor(ioList.instance, i)) {
NpcompRefJitElementType elementType;
intptr_t rank;
const int32_t *extents;
void *data = npcompRefJitValueGetTensor(
ioList.instance, i, &elementType, &rank, &extents);
const char *format;
switch (elementType) {
case NPCOMP_REFJIT_F32:
format = "f";
break;
default:
throw py::raiseValueError("unsupported tensor element type");
}
outputs.push_back(
py::array(py::dtype(format),
llvm::ArrayRef<std::int32_t>(extents, rank), data,
/*base=*/ioListObject));
} else {
throw py::raisePyError(PyExc_ValueError,
"unsupported npcomp refjit return type");
}
}
return outputs;
},
py::arg("function_name"), py::arg("inputs"));
}
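// Python-side usage sketch (hypothetical variable names; assumes numpy as np
// and a module already run through build_backend_compilation_pipeline):
//   jit = refjit.JITModule.from_compiled_module(compiled_module, shared_libs)
//   outputs = jit.invoke("forward", [np.ones((2, 2), np.float32)])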


@ -1,50 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
"""Exports configuration for the package and settings for building libraries."""
import os
import platform
from ._mlir_libs import get_include_dirs, get_lib_dirs
__all__ = [
"get_include_dirs",
"get_lib_dirs",
]
def get_capi_link_library_name() -> str:
"""Gets the library name of the CAPI shared library to link against."""
return "NPCOMPPythonCAPI"
def get_capi_link_library_path() -> str:
"""Returns an absolute path to the CAPI shared library.
  This should be preferred when creating a non-relocatable package, as it
  eliminates the possibility of interference from similarly named libraries
  on the link path.
Raises:
ValueError: If the library cannot be found.
"""
system = platform.system()
lib_prefix = "lib"
lib_suffix = ".so"
if system == "Darwin":
lib_suffix = ".dylib"
elif system == "Windows":
lib_prefix = ""
lib_suffix = ".lib"
lib_filename = f"{lib_prefix}{get_capi_link_library_name()}{lib_suffix}"
for lib_dir in get_lib_dirs():
full_path = os.path.join(lib_dir, lib_filename)
    if os.path.exists(full_path):
      return full_path
raise ValueError(
f"Unable to find npcomp development library {lib_filename} in "
f"{get_lib_dirs()}")


@ -1,66 +0,0 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
import os
import platform
_refjit = None
def get_refjit():
"""Dynamically resolves the refjit backend native module."""
global _refjit
if _refjit is not None:
return _refjit
from ...._mlir_libs import _npcomp as _cext
try:
imported_refjit = _cext.backend.refjit
except AttributeError:
raise ImportError(
"The npcomp native module was not compiled with refjit support")
_refjit = imported_refjit
return _refjit
def is_enabled() -> bool:
"""Returns whether the backend is enabled for the current build."""
try:
    get_refjit()
return True
except ImportError:
return False
def get_runtime_libs():
# The _refjit_resources directory is at the npcomp.compiler level.
  resources_dir = os.path.dirname(__file__)
suffix = ".so"
if platform.system() == "Darwin":
suffix = ".dylib"
shlib_name = f"libNPCOMPCompilerRuntimeShlib{suffix}"
return [os.path.join(resources_dir, shlib_name)]
class JitModuleInvoker:
"""Wrapper around a native JitModule for calling functions."""
def __init__(self, jit_module):
super().__init__()
self._jit_module = jit_module
def __getattr__(self, function_name):
return self.__getitem__(function_name)
def __getitem__(self, function_name):
def invoke(*args):
results = self._jit_module.invoke(function_name, args)
if len(results) == 1:
# De-tuple.
return results[0]
else:
return tuple(results)
invoke.__isnpcomp__ = True
return invoke
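# Illustrative usage (hypothetical jit_module handle; assumes numpy as np):
#   invoker = JitModuleInvoker(jit_module)
#   result = invoker.forward(np.ones((2, 2), dtype=np.float32))
#   # A single result is de-tupled, so `result` is one ndarray here.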


@ -8,7 +8,6 @@ configure_lit_site_cfg(
set(NPCOMP_TEST_DEPENDS
FileCheck count not
npcomp-opt
refback-run
NPCOMPPythonModules
)


@ -1,8 +0,0 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
// CHECK-LABEL: @alloc_memref
func @alloc_memref(%arg0: tensor<?xindex>) {
// CHECK: refback.alloc_memref
%0 = refback.alloc_memref %arg0 : memref<?xf32>
return
}


@ -1,23 +0,0 @@
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
refbackrt.module_metadata {
// expected-error @+1 {{must reference a valid func}}
refbackrt.func_metadata {funcName = @g, numInputs = 1 : i32, numOutputs = 0 : i32}
}
// -----
refbackrt.module_metadata {
// expected-error @+1 {{must agree on number of inputs}}
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
}
func @f() { return }
// -----
refbackrt.module_metadata {
// expected-error @+1 {{must agree on number of outputs}}
refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
}
func @f() { return }


@ -1,19 +0,0 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
// CHECK: refbackrt.module_metadata
refbackrt.module_metadata {
// CHECK: refbackrt.func_metadata
// TODO(brycearden): Encode unranked memrefs in the ABI
refbackrt.func_metadata {
funcName = @f,
numInputs = 1 : i32,
numOutputs = 0 : i32,
inputArgTypes = dense<1> : tensor<1xi32>,
inputElementTypes = dense<1> : tensor<1xi32>,
inputRanks = dense<-1> : tensor<1xi32>,
inputShapes = dense<1> : tensor<4xi32>}
}
func @f(%arg0: tensor<*xf32>) {
return
}


@ -1,31 +0,0 @@
// RUN: npcomp-opt -split-input-file -lower-alloc-memref-ops <%s | FileCheck %s
// CHECK-LABEL: func @basic
func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
// CHECK: %[[I:.*]] = constant 0 : index
// CHECK: %[[E:.*]] = tensor.extract %arg0[%[[I]]]
// CHECK: alloc(%[[E]])
%0 = refback.alloc_memref %arg0 : memref<?xf32>
return %0 : memref<?xf32>
}
// -----
// CHECK-LABEL: func @all_static
func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
// CHECK-NOT: tensor.extract
// CHECK: alloc()
%0 = refback.alloc_memref %arg0 : memref<3x4x5xf32>
return %0 : memref<3x4x5xf32>
}
// -----
// CHECK-LABEL: func @some_static
func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
// CHECK-DAG: %[[I1:.*]] = constant 1 : index
// CHECK-DAG: %[[E1:.*]] = tensor.extract %arg0[%[[I1]]]
// CHECK-DAG: %[[I3:.*]] = constant 3 : index
// CHECK-DAG: %[[E3:.*]] = tensor.extract %arg0[%[[I3]]]
// CHECK: alloc(%[[E1]], %[[E3]])
%0 = refback.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
return %0 : memref<3x?x5x?x7xf32>
}


@ -1,69 +0,0 @@
// RUN: npcomp-opt -lower-to-refbackrt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
// Test module metadata.
// CHECK: refbackrt.module_metadata
// CHECK-NEXT: refbackrt.func_metadata
// CHECK-SAME: funcName = @f_2inputs_0outputs
// CHECK-SAME: numInputs = 2
// CHECK-SAME: numOutputs = 0
// CHECK-NEXT: refbackrt.func_metadata
// CHECK-SAME: funcName = @f_1input_2outputs
// CHECK-SAME: numInputs = 1
// CHECK-SAME: numOutputs = 2
// This function only exists to test its metadata above.
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
return
}
// This function only exists to test its metadata above.
func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
return %arg0, %arg0 : memref<?xf32>, memref<?xf32>
}
// -----
// Test ABI conversions.
// CHECK-LABEL: func @identity(%arg0: memref<*xf32>) -> memref<*xf32>
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK: return %arg0
return %arg0 : memref<?xf32>
}
// -----
// CHECK-LABEL: func @use_of_arg(%arg0: memref<*xf32>)
func @use_of_arg(%arg0: memref<?xf32>) {
// CHECK-NEXT: %[[MEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
%c0 = constant 0 : index
%0 = memref.dim %arg0, %c0 : memref<?xf32>
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
// CHECK-NEXT: memref.dim %[[MEMREF]], %[[C0]] : memref<?xf32>
return
}
// -----
// CHECK-LABEL: func @multiple_blocks(%arg0: memref<*xf32>) -> memref<*xf32>
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: %[[INMEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
br ^bb1(%arg0: memref<?xf32>)
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
^bb1(%bbarg: memref<?xf32>):
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref.cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
// CHECK-NEXT: return %[[OUTMEMREF]] : memref<*xf32>
return %bbarg : memref<?xf32>
}
// -----
// Test diagnostics.
// expected-error @+1 {{func not expressible with refbackrt ABI}}
func @unhandled_abi_type_on_public_func(%arg0: i32) {
return
}


@ -1,11 +0,0 @@
# RUN: %PYTHON %s 2>&1
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from npcomp import build
assert build.get_include_dirs()
assert build.get_lib_dirs()
print("CAPI Path:", build.get_capi_link_library_path())


@ -49,9 +49,6 @@ config.test_source_root = os.path.dirname(__file__)
# test_exec_root: The root path where tests should be run.
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test')
config.npcomp_bin_dir = os.path.join(config.npcomp_obj_root, 'bin')
config.npcomp_runtime_shlib = os.path.join(
config.npcomp_obj_root, 'lib',
'libNPCOMPCompilerRuntimeShlib' + config.llvm_shlib_ext)
# Tweak the PATH and PYTHONPATH to include the tools dir.
npcomp_python_dir = "python" if config.npcomp_built_standalone else "tools/npcomp/python"
@ -67,8 +64,6 @@ tool_dirs = [
]
tools = [
'npcomp-opt',
'refback-run',
ToolSubst('%npcomp_runtime_shlib', config.npcomp_runtime_shlib),
]
llvm_config.add_tool_substitutions(tools, tool_dirs)


@ -1,27 +0,0 @@
// RUN: refback-run %s \
// RUN: -invoke forward \
// RUN: -arg-value="dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>" \
// RUN: -arg-value="dense<[10.0, 20.0]> : tensor<2xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK{LITERAL}: output #0: dense<[[1.100000e+01, 2.200000e+01], [1.300000e+01, 2.400000e+01]]> : tensor<2x2xf32>
#map0 = affine_map<(d0, d1) -> (d0, d1)>
#map1 = affine_map<(d0, d1) -> (d1)>
builtin.func @forward(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xf32> {
%c1 = constant 1 : index
%c0 = constant 0 : index
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
%1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
%2 = tensor.dim %arg1, %c0 : tensor<?xf32>
%3 = cmpi eq, %1, %2 : index
assert %3, "mismatched size for broadcast"
%4 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
%5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>) outs(%4 : tensor<?x?xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%6 = addf %arg2, %arg3 : f32
linalg.yield %6 : f32
} -> tensor<?x?xf32>
return %5 : tensor<?x?xf32>
}


@ -1,10 +0,0 @@
// RUN: refback-run %s \
// RUN: -invoke constant \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
func @constant() -> tensor<f32> {
%0 = constant dense<1.0> : tensor<f32>
return %0 : tensor<f32>
}


@ -1,10 +0,0 @@
// RUN: refback-run %s \
// RUN: -invoke identity \
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
func @identity(%arg0: tensor<f32>) -> tensor<f32> {
return %arg0 : tensor<f32>
}


@ -1,12 +0,0 @@
// RUN: not refback-run %s \
// RUN: -invoke expects_one_tensor \
// RUN: -arg-value="1.0 : f32" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: invoking 'expects_one_tensor': input argument type mismatch.
// CHECK-SAME: actual (provided by user): Float
// CHECK-SAME: expected (from compiler): kTensor
func @expects_one_tensor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
return %arg0 : tensor<?xf32>
}


@ -1,9 +0,0 @@
// RUN: not refback-run %s \
// RUN: -invoke requires_one_input \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: invoking 'requires_one_input': expected 1 inputs
func @requires_one_input(%arg0: tensor<?xf32>) -> tensor<?xf32> {
return %arg0 : tensor<?xf32>
}


@ -1,20 +0,0 @@
// RUN: refback-run %s \
// RUN: -invoke multi_output \
// RUN: -arg-value="dense<1.0> : tensor<1xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<2.000000e+00> : tensor<1xf32>
// CHECK: output #1: dense<2.000000e+00> : tensor<1xf32>
#map0 = affine_map<(d0) -> (d0)>
func @multi_output(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
%c0 = constant 0 : index
%0 = tensor.dim %arg0, %c0 : tensor<?xf32>
%1 = linalg.init_tensor [%0] : tensor<?xf32>
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) outs(%1 : tensor<?xf32>) {
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
%6 = addf %arg2, %arg3 : f32
linalg.yield %6 : f32
} -> tensor<?xf32>
return %2, %2 : tensor<?xf32>, tensor<?xf32>
}


@ -1,10 +0,0 @@
// RUN: refback-run %s \
// RUN: -invoke scalar_arg \
// RUN: -arg-value="2.5 : f32" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: 2.500000e+00 : f32
func @scalar_arg(%arg0: f32) -> f32 {
return %arg0 : f32
}


@ -1,3 +1,2 @@
add_subdirectory(npcomp-lsp-server)
add_subdirectory(npcomp-opt)
add_subdirectory(refback-run)


@ -18,19 +18,6 @@ npcomp-opt() {
"$@"
}
refback-run() {
# Helper for building and invoking refback-run.
#
# This also automatically builds and adds the npcomp runtime shared
# library.
#
# Usage:
# $ refback-run <regular refback-run options>
ninja -C "$build_dir" refback-run NPCOMPCompilerRuntimeShlib 1>&2 || return 1
"$build_dir/bin/refback-run" \
-shared-libs="${build_dir}/lib/libNPCOMPCompilerRuntimeShlib.so" "$@"
}
# Go to the root of your npcomp checkout.
alias npd="cd $td"
# Handy so that `npctest -v` runs lit with -v and thus prints out errors,


@ -1,30 +0,0 @@
# refback-run is always linked dynamically as we want to distribute the
# binaries with the python packages for hacking/debugging.
get_property(dialect_libs GLOBAL PROPERTY NPCOMP_DIALECT_LIBS)
get_property(conversion_libs GLOBAL PROPERTY NPCOMP_CONVERSION_LIBS)
add_npcomp_executable(refback-run
refback-run.cpp
)
llvm_update_compile_flags(refback-run)
target_link_libraries(refback-run PRIVATE
NPCOMPCAPI
NPCOMPInitAll
MLIRAnalysis
MLIRIR
MLIRJitRunner
MLIRParser
MLIRSupport
NPCOMPRefBackendJITHelpers
TorchMLIRInitAll
# TODO: Remove these in favor of interface deps.
${conversion_libs}
${dialect_libs}
)
add_dependencies(refback-run
NPCOMPCompilerRuntimeShlib
)


@ -1,252 +0,0 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Utility binary for compiling and running code through the npcomp
// RefBackend. The accepted input is the npcomp backend contract
// (roughly, linalg-on-tensors + std).
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/InitAllDialects.h"
#include "mlir/InitAllPasses.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "npcomp-c/InitLLVM.h"
#include "npcomp/InitAll.h"
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
#include "torch-mlir/InitAll.h"
#include "llvm/Support/InitLLVM.h"
using namespace mlir;
using llvm::Error;
using llvm::Expected;
using llvm::StringError;
using llvm::Twine;
/// Wrap a string into an llvm::StringError.
static Error make_string_error(const Twine &message) {
return llvm::make_error<StringError>(message.str(),
llvm::inconvertibleErrorCode());
}
static Expected<refbackrt::Ref<refbackrt::Tensor>>
convertAttrToTensor(Attribute attr) {
auto type = attr.getType().dyn_cast<RankedTensorType>();
if (!type)
return make_string_error("unhandled argument type; must be a tensor type");
auto extents = llvm::to_vector<6>(llvm::map_range(
type.getShape(), [](int64_t x) { return static_cast<std::int32_t>(x); }));
auto elementType = type.getElementType();
auto denseFp = attr.dyn_cast<DenseFPElementsAttr>();
if (denseFp) {
if (elementType.isF32()) {
auto values = llvm::to_vector<100>(llvm::map_range(
denseFp, [](APFloat f) { return f.convertToFloat(); }));
return refbackrt::Tensor::create(
refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
refbackrt::ElementType::F32, static_cast<void *>(values.data()));
}
} else {
return make_string_error("unhandled argument; must be dense floating-point");
}
return make_string_error("unhandled argument");
}
static Expected<float> convertAttrToFloat(Attribute attr) {
  auto type = attr.getType().dyn_cast<FloatType>();
  if (!type)
    return make_string_error("cannot convert a non-float argument to float");
  auto floatAttr = attr.dyn_cast<FloatAttr>();
  if (!floatAttr)
    return make_string_error("expected a float constant argument value");
  return floatAttr.getValue().convertToFloat();
}
static Expected<SmallVector<refbackrt::RtValue, 6>>
createInputs(ArrayRef<StringRef> argValues) {
MLIRContext context;
SmallVector<refbackrt::RtValue, 6> ret;
for (auto argValue : argValues) {
auto attr = parseAttribute(argValue, &context);
if (!attr)
return make_string_error(Twine("could not parse arg value: ") + argValue);
auto attrType = attr.getType();
if (attrType.isa<RankedTensorType>()) {
auto expectedTensor = convertAttrToTensor(attr);
if (!expectedTensor)
return expectedTensor.takeError();
ret.push_back(std::move(*expectedTensor));
    } else if (attrType.isa<FloatType>()) {
      auto expectedFloat = convertAttrToFloat(attr);
      if (!expectedFloat)
        return expectedFloat.takeError();
      ret.push_back(refbackrt::RtValue(*expectedFloat));
    } else {
      return make_string_error(Twine("unsupported arg value type: ") +
                               argValue);
    }
}
return ret;
}
static Type convertToMLIRType(refbackrt::ElementType type, Builder &builder) {
switch (type) {
case refbackrt::ElementType::F32:
return builder.getF32Type();
default:
llvm_unreachable("unsupported dtype");
}
}
static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor,
Builder &builder) {
auto elementType = convertToMLIRType(tensor.getElementType(), builder);
SmallVector<int64_t, 6> extents;
for (int i = 0, e = tensor.getRank(); i < e; i++)
extents.push_back(tensor.getExtent(i));
return RankedTensorType::get(extents, elementType);
}
static Attribute convertToMLIRAttribute(const refbackrt::RtValue &value,
Builder &builder) {
if (value.isTensor()) {
auto& tensor = *(value.toTensor());
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
switch (tensor.getElementType()) {
case refbackrt::ElementType::F32: {
SmallVector<float, 100> values;
auto *basePtr = tensor.getData<float>();
for (int i = 0, e = type.getNumElements(); i < e; i++)
values.push_back(basePtr[i]);
return DenseFPElementsAttr::get(type, values);
}
default:
llvm_unreachable("unsupported element type");
}
} else if (value.isFloat()) {
return builder.getF32FloatAttr(value.toFloat());
}
llvm_unreachable("unsupported type");
}
static void printOutput(const refbackrt::RtValue &value, llvm::raw_ostream &os) {
MLIRContext context;
Builder builder(&context);
auto attr = convertToMLIRAttribute(value, builder);
attr.print(os);
}
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
llvm::raw_ostream &os) {
for (auto output : llvm::enumerate(outputs)) {
os << "output #" << output.index() << ": ";
printOutput(output.value(), os);
os << "\n";
}
}
Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context,
std::string invokeFunction, ArrayRef<StringRef> argValues,
ArrayRef<StringRef> sharedLibs, bool optimize) {
OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context);
if (!moduleRef)
return make_string_error(Twine("could not open ") + mlirFile);
ModuleOp module = *moduleRef;
// Compile.
PassManager pm(module.getContext(), OpPassManager::Nesting::Implicit);
applyPassManagerCLOptions(pm);
refback::JITModule::buildBackendCompilationPipeline(pm, optimize);
if (failed(pm.run(module))) {
return make_string_error(Twine("error compiling to jit backend"));
}
auto expectedJitModule =
refback::JITModule::fromCompiledModule(module, sharedLibs);
if (!expectedJitModule)
return expectedJitModule.takeError();
auto jitModule = std::move(*expectedJitModule);
auto expectedInputs = createInputs(argValues);
if (!expectedInputs)
return expectedInputs.takeError();
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
if (!expectedOutputs)
return expectedOutputs.takeError();
auto outputs = std::move(*expectedOutputs);
printOutputs(outputs, llvm::outs());
llvm::outs() << "SUCCESS\n";
return Error::success();
}
//===----------------------------------------------------------------------===//
// Main-related init and option parsing.
//===----------------------------------------------------------------------===//
namespace {
namespace cl = llvm::cl;
struct Options {
cl::opt<std::string> inputFile{
cl::Positional, cl::desc("the input .mlir file"), cl::init("-")};
cl::opt<std::string> invokeFunction{"invoke", cl::Required,
cl::desc("function to invoke")};
cl::list<std::string> argValues{"arg-value", cl::ZeroOrMore,
cl::desc("Arguments to the called function")};
cl::list<std::string> sharedLibs{"shared-libs", cl::ZeroOrMore,
cl::MiscFlags::CommaSeparated,
cl::desc("Libraries to link dynamically")};
cl::opt<bool> optimize{
"optimize", cl::Optional,
cl::desc("whether the refback pass pipeline should run optimizations"),
cl::init(false)};
};
} // namespace
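// Example invocation (mirroring the RUN lines in the e2e tests):
//   refback-run input.mlir -invoke forward \
//     -arg-value="dense<1.0> : tensor<f32>" \
//     -shared-libs=libNPCOMPCompilerRuntimeShlib.so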
int main(int argc, char **argv) {
mlir::registerMLIRContextCLOptions();
mlir::registerAsmPrinterCLOptions();
mlir::registerPassManagerCLOptions();
Options options;
llvm::cl::ParseCommandLineOptions(argc, argv, "npcomp compile+run utility\n");
mlir::DialectRegistry registry;
mlir::registerAllDialects(registry);
mlir::registerAllPasses();
mlir::NPCOMP::registerAllDialects(registry);
mlir::NPCOMP::registerAllPasses();
mlir::torch::registerAllDialects(registry);
mlir::torch::registerAllPasses();
MLIRContext context;
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
llvm::InitLLVM y(argc, argv);
npcompInitializeLLVMCodegen();
SmallVector<StringRef, 6> sharedLibs(options.sharedLibs.begin(),
options.sharedLibs.end());
SmallVector<StringRef, 6> argValues(options.argValues.begin(),
options.argValues.end());
Error error =
compileAndRun(options.inputFile, context, options.invokeFunction,
argValues, sharedLibs, options.optimize);
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),
[&exitCode](const llvm::ErrorInfoBase &info) {
llvm::errs() << "Error: ";
info.log(llvm::errs());
llvm::errs() << '\n';
exitCode = EXIT_FAILURE;
});
return exitCode;
}