mirror of https://github.com/llvm/torch-mlir
parent
f9c48d0b89
commit
a25163fbfa
|
@ -22,7 +22,6 @@ endif()
|
|||
# Options and settings
|
||||
#-------------------------------------------------------------------------------
|
||||
set(NPCOMP_MINIMUM_PYTHON_VERSION 3.6)
|
||||
option(NPCOMP_ENABLE_REFJIT "Enables the reference JIT backend." ON)
|
||||
set(NPCOMP_IREE_BUILDDIR "../iree-build" CACHE STRING "If building IREE, then setting this elects to build from a source directory (versus installed package)")
|
||||
|
||||
# Turn on -gsplit-dwarf if requested in debug builds.
|
||||
|
@ -201,11 +200,6 @@ set(NPCOMP_TABLEGEN_ARGS "")
|
|||
# Optional feature selection
|
||||
#-------------------------------------------------------------------------------
|
||||
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
add_compile_definitions(NPCOMP_ENABLE_REFJIT)
|
||||
message(STATUS "Reference JIT backend enabled")
|
||||
endif()
|
||||
|
||||
#-------------------------------------------------------------------------------
|
||||
# Subdirectories and aggregate testing targets.
|
||||
#-------------------------------------------------------------------------------
|
||||
|
|
|
@ -1,99 +0,0 @@
|
|||
//===-- npcomp-c/RefJITBackend.h - C API for the reference JIT ----*- C -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM
|
||||
// Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_C_REFJITBACKEND_H
|
||||
#define NPCOMP_C_REFJITBACKEND_H
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include "mlir-c/Pass.h"
|
||||
#include "mlir-c/Support.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
// Define opaque API structs.
|
||||
#define DEFINE_C_API_STRUCT(name, storage) \
|
||||
struct name { \
|
||||
storage *ptr; \
|
||||
}; \
|
||||
typedef struct name name
|
||||
|
||||
DEFINE_C_API_STRUCT(NpcompRefJitModule, void);
|
||||
DEFINE_C_API_STRUCT(NpcompRefJitValueList, void);
|
||||
|
||||
#undef DEFINE_C_API_STRUCT
|
||||
|
||||
// Must be kept in sync with C++ side.
|
||||
enum NpcompRefJitElementType {
|
||||
NPCOMP_REFJIT_NONE = 0,
|
||||
NPCOMP_REFJIT_F32 = 1,
|
||||
};
|
||||
|
||||
/// Populates a PassManager with a pipeline that performs backend compilation.
|
||||
/// The resulting module can be passed to npcompRefJitModuleCreate().
|
||||
MLIR_CAPI_EXPORTED void
|
||||
npcompRefJitBuildBackendCompilationPipeline(MlirPassManager passManager,
|
||||
bool optimize);
|
||||
|
||||
/// Creates a RefJit module from an MlirModule (as compiled from the above
|
||||
/// pipeline). On success, returns a !null NpcompRefJitModule. On failure,
|
||||
/// returns null and malloc() allocates an error message into *errorMessage.
|
||||
/// The caller must free these messages.
|
||||
MLIR_CAPI_EXPORTED NpcompRefJitModule
|
||||
npcompRefJitModuleCreate(MlirModule module, MlirStringRef *sharedLibs,
|
||||
intptr_t sharedLibsSize, char **errorMessage);
|
||||
|
||||
/// Whether the module is null.
|
||||
static inline bool npcompRefJitModuleIsNull(NpcompRefJitModule m) {
|
||||
return !m.ptr;
|
||||
}
|
||||
|
||||
/// Destroys a refjit module.
|
||||
MLIR_CAPI_EXPORTED void npcompRefJitModuleDestroy(NpcompRefJitModule module);
|
||||
|
||||
/// Invokes a function on a RefJit module. On success, returns true and malloc()
|
||||
/// and adds all outputs to the passed outputs list. On failure, returns false
|
||||
/// and populates *errorMessage with a malloc() allocated error message, which
|
||||
/// must be caller freed.
|
||||
MLIR_CAPI_EXPORTED bool
|
||||
npcompRefJitModuleInvoke(NpcompRefJitModule m, MlirStringRef functionName,
|
||||
NpcompRefJitValueList inputOutputs,
|
||||
char **errorMessage);
|
||||
|
||||
/// Creates an empty value list.
|
||||
MLIR_CAPI_EXPORTED NpcompRefJitValueList npcompRefJitValueListCreate();
|
||||
|
||||
/// Destroys a value list.
|
||||
MLIR_CAPI_EXPORTED void
|
||||
npcompRefJitValueListDestroy(NpcompRefJitValueList list);
|
||||
|
||||
/// Returns the size of the value list.
|
||||
MLIR_CAPI_EXPORTED intptr_t
|
||||
npcompRefJitValueListSize(NpcompRefJitValueList list);
|
||||
|
||||
/// Adds values to the list.
|
||||
MLIR_CAPI_EXPORTED void npcompRefJitValueAddTensorCopy(
|
||||
NpcompRefJitValueList list, NpcompRefJitElementType elementType,
|
||||
const int32_t *extents, intptr_t extentsSize, const void *data);
|
||||
|
||||
// Reads Tensor from a list.
|
||||
MLIR_CAPI_EXPORTED bool npcompRefJitValueIsaTensor(NpcompRefJitValueList list,
|
||||
intptr_t i);
|
||||
MLIR_CAPI_EXPORTED void *
|
||||
npcompRefJitValueGetTensor(NpcompRefJitValueList list, intptr_t i,
|
||||
NpcompRefJitElementType *elementType, intptr_t *rank,
|
||||
const int32_t **extents);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // NPCOMP_C_REFJITBACKEND_H
|
|
@ -1,4 +1,3 @@
|
|||
add_subdirectory(Backend)
|
||||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(RefBackend)
|
||||
|
|
|
@ -1,3 +1 @@
|
|||
add_subdirectory(Refback)
|
||||
add_subdirectory(Refbackrt)
|
||||
add_subdirectory(TorchConversion)
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
add_subdirectory(IR)
|
|
@ -1 +0,0 @@
|
|||
add_mlir_dialect(RefbackOps refback)
|
|
@ -1,23 +0,0 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACK_BASE
|
||||
#define REFBACK_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Refback_Dialect : Dialect {
|
||||
let name = "refback";
|
||||
let cppNamespace = "::mlir::NPCOMP::refback";
|
||||
let description = [{
|
||||
Ops used by the reference backend as part of its lowering.
|
||||
}];
|
||||
}
|
||||
|
||||
|
||||
#endif // REFBACK_BASE
|
|
@ -1,16 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_REFBACK_IR_REFBACKDIALECT_H
|
||||
#define NPCOMP_DIALECT_REFBACK_IR_REFBACKDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOpsDialect.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_REFBACK_IR_REFBACKDIALECT_H
|
|
@ -1,22 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H
|
||||
#define NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H
|
||||
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_REFBACK_IR_REFBACKOPS_H
|
|
@ -1,40 +0,0 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACK_OPS
|
||||
#define REFBACK_OPS
|
||||
|
||||
include "npcomp/Dialect/Refback/IR/RefbackBase.td"
|
||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
class Refback_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Refback_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops related to bufferization.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def Refback_AllocMemRefOp : Refback_Op<"alloc_memref", []> {
|
||||
let summary = "Allocates a memref of the given shape.";
|
||||
let description = [{
|
||||
Allocates a memref of the given shape.
|
||||
|
||||
This op is a convenience for creating a bunch of
|
||||
tensor.extract ops + std.alloc.
|
||||
}];
|
||||
let arguments = (ins Shape_ExtentTensorType:$shape);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
#endif // REFBACK_OPS
|
|
@ -1 +0,0 @@
|
|||
add_subdirectory(IR)
|
|
@ -1 +0,0 @@
|
|||
add_mlir_dialect(RefbackrtOps refbackrt)
|
|
@ -1,26 +0,0 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACKRT_BASE
|
||||
#define REFBACKRT_BASE
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def Refbackrt_Dialect : Dialect {
|
||||
let name = "refbackrt";
|
||||
let cppNamespace = "::mlir::NPCOMP::refbackrt";
|
||||
let description = [{
|
||||
The `refbackrt` dialect is the IR manifestation for interaction with the
|
||||
reference backend runtime. It primarily serves as a layer that enapsulates the
|
||||
data structures and functions available in the runtime, and faciliates
|
||||
conversion to those conventions, such as by providing utilities for being
|
||||
lowered to the llvm dialect.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // #ifndef REFBACKRT_BASE
|
|
@ -1,31 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
|
||||
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
|
||||
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
namespace refbackrt {
|
||||
|
||||
class TensorType : public Type::TypeBase<TensorType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static TensorType get(MLIRContext *context) { return Base::get(context); }
|
||||
};
|
||||
|
||||
} // namespace refbackrt
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOpsDialect.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
|
|
@ -1,20 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
|
||||
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
|
||||
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h.inc"
|
||||
|
||||
#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
|
|
@ -1,119 +0,0 @@
|
|||
//===-------------------------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACKRT_OPS
|
||||
#define REFBACKRT_OPS
|
||||
|
||||
include "npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
class Refbackrt_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Refbackrt_Dialect, mnemonic, traits> {
|
||||
}
|
||||
|
||||
def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
|
||||
let summary = "Aborts if the predicate is true";
|
||||
let description = [{
|
||||
Aborts if the predicate is true.
|
||||
}];
|
||||
let arguments = (ins I1:$pred, StrAttr:$msg);
|
||||
let results = (outs);
|
||||
let assemblyFormat = "$pred `,` $msg attr-dict";
|
||||
}
|
||||
|
||||
def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [
|
||||
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
||||
]> {
|
||||
let summary = "Global metadata for the module";
|
||||
let description = [{
|
||||
This op contains a region containing refbackrt.func_metadata ops,
|
||||
which give information about the functions in the module. This allows
|
||||
the module to be introspected when it is loaded, such as looking up
|
||||
functions.
|
||||
Future uses are checking how many results functions should have, or
|
||||
what their argument types are expected to be to provide clean and safe
|
||||
errors when invocations fail.
|
||||
|
||||
TODO: Verify that there should be no more than one of these ops in a
|
||||
module.
|
||||
|
||||
This op is designed to hold a region, which makes it easy to convert to
|
||||
a single LLVM global with a single conversion pattern.
|
||||
}];
|
||||
let arguments = (ins);
|
||||
let results = (outs);
|
||||
let regions = (region SizedRegion<1>:$metadatas);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def Refbackrt_ModuleMetadataTerminatorOp
|
||||
: Refbackrt_Op<"module_metadata_terminator",
|
||||
[Terminator, HasParent<"ModuleMetadataOp">]> {
|
||||
let summary = "Implicit terminator for ModuleMetadataOp's region";
|
||||
let arguments = (ins);
|
||||
let results = (outs);
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def Refbackrt_FuncMetadataOp
|
||||
: Refbackrt_Op<"func_metadata", [HasParent<"ModuleMetadataOp">]> {
|
||||
let summary = "Runtime metadata for a single func";
|
||||
let description = [{
|
||||
Runtime metadata for a single func.
|
||||
|
||||
Contains type / shape information for arguments as described below:
|
||||
|
||||
* ArgType(s):
|
||||
Integer value from `CompilerDataStructures.h` for each argument
|
||||
indicating what type it is (e.g. Float, Int, Tensor, Dict, etc.)
|
||||
* ElementType(s):
|
||||
Certain input ArgType's also have an element type (e.g. Tensor<float>,
|
||||
List<int>, etc.)
|
||||
TODO(brycearden): Support nested types (e.g. List<Tensor<float>>)
|
||||
* Rank(s):
|
||||
Integer value indicating the rank for each argument.
|
||||
* Shape(s):
|
||||
Flattened hyper-rectangular representation of the shapes for each argument.
|
||||
Since each shape's size varies based on the Rank, we pad out the shapes
|
||||
to size kMaxRank to make ABI lowering easier. See `LowerToRefbackrtABI.cpp`
|
||||
for details.
|
||||
|
||||
Shapes Example:
|
||||
constexpr int kMaxRank = 6;
|
||||
// func @f(%arg0: f32, %arg1: tensor<5xf32>) would result in...
|
||||
inputShapes = dense<...> : tensor<12xi32>
|
||||
// 2 shapes with 6 elements each so that the LowerToLLVM pass
|
||||
// where only the first `rank` values in each shape are valid.
|
||||
//
|
||||
// can update the struct(s) by just grabbing a pointer at
|
||||
// %shape_ptr = %base + (kMaxRank * argIndex)
|
||||
}];
|
||||
let arguments = (ins
|
||||
FlatSymbolRefAttr:$funcName,
|
||||
I32Attr:$numInputs,
|
||||
I32Attr:$numOutputs,
|
||||
OptionalAttr<I32ElementsAttr>:$inputArgTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$inputElementTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$inputRanks,
|
||||
OptionalAttr<I32ElementsAttr>:$inputShapes,
|
||||
// I32ElementsAttr:$inputIsStatic,
|
||||
OptionalAttr<I32ElementsAttr>:$outputArgTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$outputElementTypes,
|
||||
OptionalAttr<I32ElementsAttr>:$outputRanks,
|
||||
OptionalAttr<I32ElementsAttr>:$outputShapes
|
||||
//I32ElementsAttr:$outputIsStatic
|
||||
);
|
||||
let results = (outs);
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // #ifndef REFBACKRT_OPS
|
|
@ -1,5 +0,0 @@
|
|||
set(LLVM_TARGET_DEFINITIONS Passes.td)
|
||||
mlir_tablegen(Passes.h.inc -gen-pass-decls)
|
||||
add_public_tablegen_target(NPCOMPRefBackendPassIncGen)
|
||||
|
||||
add_mlir_doc(Passes RefBackendPasses ./ -gen-pass-doc)
|
|
@ -1,54 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_JITRUNTIME_JITMODULE_H
|
||||
#define NPCOMP_JITRUNTIME_JITMODULE_H
|
||||
|
||||
#include "mlir/ExecutionEngine/ExecutionEngine.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {
|
||||
class PassManager;
|
||||
} // namespace mlir
|
||||
|
||||
namespace refback {
|
||||
// Wrapper around refbackrt data structures and a JITted module, facilitating
|
||||
// interaction.
|
||||
class JITModule {
|
||||
public:
|
||||
/// Populates a PassManager with a pipeline that performs backend compilation.
|
||||
/// The resulting module can be passed to fromCompiledModule().
|
||||
static void buildBackendCompilationPipeline(mlir::PassManager &pm,
|
||||
bool optimize = false);
|
||||
|
||||
/// Constructs a JITModule from a compiled Module.
|
||||
/// The module should be the result of having run the backend compilation
|
||||
/// pipeline successfully.
|
||||
static llvm::Expected<std::unique_ptr<JITModule>>
|
||||
fromCompiledModule(mlir::ModuleOp module,
|
||||
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
||||
|
||||
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
|
||||
invoke(llvm::StringRef functionName,
|
||||
llvm::ArrayRef<refbackrt::RtValue> inputs);
|
||||
|
||||
private:
|
||||
JITModule();
|
||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||
refbackrt::ModuleDescriptor *descriptor;
|
||||
};
|
||||
} // namespace refback
|
||||
|
||||
#endif // NPCOMP_JITRUNTIME_JITMODULE_H
|
|
@ -1,9 +0,0 @@
|
|||
Utilities for compiling and running on the reference backend with a JIT.
|
||||
|
||||
The runtime itself lives in {include,lib}/RefBackend/Runtime, but since it
|
||||
is totally firewalled from the compiler codebase, it presents a fairly
|
||||
bare-bones interface (e.g. it doesn't use libSupport, can't use LLVM's JIT
|
||||
interfaces, etc.).
|
||||
|
||||
The interface provided in this directory uses standard LLVM conventions and
|
||||
freely relies on libSupport, JIT utilities, etc.
|
|
@ -1,40 +0,0 @@
|
|||
//===-- Passes.td - Pass definition file -------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_REFBACKEND_PASSES
|
||||
#define NPCOMP_REFBACKEND_PASSES
|
||||
|
||||
include "mlir/Pass/PassBase.td"
|
||||
|
||||
def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
|
||||
let summary = "Lower constructs requiring runtime support to `refbackrt`";
|
||||
let description = [{
|
||||
We have a specialized dialect `refbackrt` which models our runtime's data
|
||||
structures, and function signatures (and presumably eventually, other
|
||||
ABI boundaries like external calls if we ever support it) will be
|
||||
converted.
|
||||
|
||||
The constructs requiring runtime support are:
|
||||
- function signatures / module metadata
|
||||
- error handling
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()";
|
||||
}
|
||||
|
||||
def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
|
||||
let summary = "Lower AllocMemRefOp's";
|
||||
let constructor = "mlir::NPCOMP::createLowerAllocMemRefOpsPass()";
|
||||
let dependentDialects = ["tensor::TensorDialect", "memref::MemRefDialect"];
|
||||
}
|
||||
|
||||
def LowerToLLVM : Pass<"refback-lower-to-llvm", "ModuleOp"> {
|
||||
let summary = "Lower everything to LLVM";
|
||||
let constructor = "mlir::NPCOMP::createLowerToLLVMPass();";
|
||||
}
|
||||
|
||||
#endif // NPCOMP_REFBACKEND_PASSES
|
|
@ -1,51 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_REFBACKEND_REFBACKEND_H
|
||||
#define NPCOMP_REFBACKEND_REFBACKEND_H
|
||||
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
||||
/// Registers all RefBackend passes.
|
||||
void registerRefBackendPasses();
|
||||
|
||||
// Look in createRefBackendLoweringPipeline for more information about how these
|
||||
// passes fit together.
|
||||
//
|
||||
// Pass summaries are in Passes.td.
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToRefbackrtABIPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToLLVMPass();
|
||||
|
||||
std::unique_ptr<Pass> createRestrictedCanonicalizerPass();
|
||||
|
||||
struct RefBackendLoweringPipelineOptions
|
||||
: public PassPipelineOptions<RefBackendLoweringPipelineOptions> {
|
||||
// If this option is true, then perform optimizations.
|
||||
// If this option is false, only do the bare minimum for correctness.
|
||||
Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
|
||||
llvm::cl::init(false)};
|
||||
};
|
||||
|
||||
// The main pipeline that encapsulates the full RefBackend lowering.
|
||||
void createRefBackendLoweringPipeline(
|
||||
OpPassManager &pm, const RefBackendLoweringPipelineOptions &options);
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // namespace mlir
|
||||
|
||||
#endif // NPCOMP_REFBACKEND_REFBACKEND_H
|
|
@ -1,14 +0,0 @@
|
|||
Refbackrt (namespace `refbackrt`) is the runtime support library for the
|
||||
RefBackend backend. It is best practice to keep compiler and runtime code
|
||||
totally firewalled.
|
||||
|
||||
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
As such, this directory should have NO DEPENDENCIES ON COMPILER CODE (no
|
||||
LLVM libSupport, etc.).
|
||||
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
||||
|
||||
This will cause some duplication, but history has shown that this
|
||||
firewalling pays big dividends. In particular, compiler code very
|
||||
frequently has binary sizes that are simply unacceptable in runtime
|
||||
scenarios, such as MByte-sized dependencies like LLVM libSupport.
|
||||
Runtime code should fit in kBytes.
|
|
@ -1,101 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Trimmed down support classes intended to provide a familiar LLVM-like API,
|
||||
// but without actually pulling in the LLVM ones.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_RUNTIME_SUPPORT_H
|
||||
#define NPCOMP_RUNTIME_SUPPORT_H
|
||||
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
namespace refbackrt {
|
||||
class StringRef {
|
||||
public:
|
||||
StringRef(const char *ptr, std::size_t length) : ptr(ptr), length(length){};
|
||||
// Construct from NUL-terminated C string.
|
||||
StringRef(const char *ptr) : ptr(ptr), length(std::strlen(ptr)) {}
|
||||
bool equals(StringRef other) {
|
||||
if (length != other.length)
|
||||
return false;
|
||||
return std::memcmp(ptr, other.ptr, length) == 0;
|
||||
}
|
||||
|
||||
const char* str() { return ptr; }
|
||||
|
||||
private:
|
||||
const char *ptr;
|
||||
std::size_t length;
|
||||
};
|
||||
inline bool operator==(StringRef lhs, StringRef rhs) { return lhs.equals(rhs); }
|
||||
inline bool operator!=(StringRef lhs, StringRef rhs) {
|
||||
return !operator==(lhs, rhs);
|
||||
}
|
||||
|
||||
template <typename T> class ArrayRef {
|
||||
public:
|
||||
ArrayRef(const T *ptr, std::size_t length) : ptr(ptr), length(length){};
|
||||
const T &operator[](std::size_t i) const {
|
||||
assert(i < length);
|
||||
return ptr[i];
|
||||
}
|
||||
const T *data() const { return ptr; }
|
||||
std::size_t size() const { return length; }
|
||||
|
||||
private:
|
||||
const T *ptr;
|
||||
std::size_t length;
|
||||
};
|
||||
|
||||
template <typename T> class MutableArrayRef {
|
||||
public:
|
||||
MutableArrayRef(T *ptr, std::size_t length) : ptr(ptr), length(length){};
|
||||
T &operator[](std::size_t i) {
|
||||
assert(i < length);
|
||||
return ptr[i];
|
||||
}
|
||||
T *data() const { return ptr; }
|
||||
std::size_t size() const { return length; }
|
||||
|
||||
private:
|
||||
T *ptr;
|
||||
std::size_t length;
|
||||
};
|
||||
|
||||
// Literally copied from MLIR.
|
||||
struct LogicalResult {
|
||||
enum ResultEnum { Success, Failure } value;
|
||||
LogicalResult(ResultEnum v) : value(v) {}
|
||||
};
|
||||
|
||||
inline LogicalResult success(bool isSuccess = true) {
|
||||
return LogicalResult{isSuccess ? LogicalResult::Success
|
||||
: LogicalResult::Failure};
|
||||
}
|
||||
|
||||
inline LogicalResult failure(bool isFailure = true) {
|
||||
return LogicalResult{isFailure ? LogicalResult::Failure
|
||||
: LogicalResult::Success};
|
||||
}
|
||||
|
||||
inline bool succeeded(LogicalResult result) {
|
||||
return result.value == LogicalResult::Success;
|
||||
}
|
||||
|
||||
inline bool failed(LogicalResult result) {
|
||||
return result.value == LogicalResult::Failure;
|
||||
}
|
||||
|
||||
} // namespace refbackrt
|
||||
|
||||
#endif // NPCOMP_RUNTIME_SUPPORT_H
|
|
@ -1,421 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the public-facing interface for interacting with the npcomp
|
||||
// runtime.
|
||||
//
|
||||
// This functionality is totally firewalled from the compiler codebase, so
|
||||
// even if things superficially look similar, remember that there are no
|
||||
// LLVM utilities here, memory allocation should be kept to a minimum, etc.
|
||||
//
|
||||
// npcomp/RefBackend/Runtime/Support.h provides some minimal LLVM-like support
|
||||
// code to keep the API familiar.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_RUNTIME_USERAPI_H
|
||||
#define NPCOMP_RUNTIME_USERAPI_H
|
||||
|
||||
#include "npcomp/RefBackend/Runtime/Support.h"
|
||||
#include <array>
|
||||
#include <atomic>
|
||||
#include <cstdlib>
|
||||
|
||||
namespace refbackrt {
|
||||
|
||||
struct RtValue;
|
||||
|
||||
// Base class for any RefCounted object type
|
||||
class RefTarget {
|
||||
protected:
|
||||
template <typename T> friend class Ref;
|
||||
mutable std::atomic<size_t> refCount;
|
||||
|
||||
constexpr RefTarget() noexcept : refCount(0) {}
|
||||
};
|
||||
|
||||
// Reference-counted handle to a type with a `refCount` member.
|
||||
template <typename T> class Ref {
|
||||
public:
|
||||
Ref() { ptr = nullptr; }
|
||||
// Creates a Ref and increments the refcount by 1.
|
||||
// rawPtr must be allocated with std::malloc.
|
||||
Ref(T *rawPtr) {
|
||||
assert(rawPtr->refCount >= 0 && "expected non-negative refcount to start!");
|
||||
ptr = rawPtr;
|
||||
incref(ptr);
|
||||
}
|
||||
Ref(const Ref &other) {
|
||||
ptr = other.ptr;
|
||||
incref(ptr);
|
||||
}
|
||||
Ref(Ref &&other) { ptr = other.takePtr(); }
|
||||
Ref &operator=(const Ref &other) {
|
||||
if (&other == this)
|
||||
return *this;
|
||||
decref(ptr);
|
||||
ptr = other.ptr;
|
||||
incref(ptr);
|
||||
return *this;
|
||||
}
|
||||
Ref &operator=(Ref &&other) {
|
||||
if (&other == this)
|
||||
return *this;
|
||||
decref(ptr);
|
||||
ptr = other.takePtr();
|
||||
return *this;
|
||||
}
|
||||
~Ref() { decref(ptr); }
|
||||
|
||||
T &operator*() const { return *ptr; }
|
||||
T *operator->() const { return ptr; }
|
||||
T *get() const { return ptr; }
|
||||
|
||||
T *takePtr() {
|
||||
auto *ret = ptr;
|
||||
ptr = nullptr;
|
||||
return ret;
|
||||
}
|
||||
|
||||
int debugGetRefCount() { return ptr->refCount; }
|
||||
|
||||
private:
|
||||
friend struct RtValue;
|
||||
static void incref(T *ptr) {
|
||||
if (!ptr)
|
||||
return;
|
||||
ptr->refCount += 1;
|
||||
}
|
||||
|
||||
friend struct RtValue;
|
||||
static void decref(T *ptr) {
|
||||
if (!ptr)
|
||||
return;
|
||||
if (ptr->refCount.fetch_sub(1) == 1) {
|
||||
ptr->~T();
|
||||
std::free(static_cast<void *>(ptr));
|
||||
}
|
||||
}
|
||||
T *ptr;
|
||||
};
|
||||
|
||||
// The available data types.
|
||||
enum class ElementType : std::int32_t {
|
||||
NONE,
|
||||
F32,
|
||||
};
|
||||
std::int32_t getElementTypeByteSize(ElementType type);
|
||||
StringRef getElementTypeAsStringRef(ElementType type);
|
||||
|
||||
// Representation of a tensor.
|
||||
class Tensor : public RefTarget {
|
||||
public:
|
||||
// Due to tail-allocated objects, this struct should never be directly
|
||||
// constructed.
|
||||
Tensor() = delete;
|
||||
|
||||
// Create a Tensor with the given extents and element type, with a buffer
|
||||
// holding a copy of `data`.
|
||||
static Ref<Tensor> create(ArrayRef<std::int32_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
// Same as `create`, but returns a raw pointer.
|
||||
static Tensor *createRaw(ArrayRef<std::int32_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
|
||||
static Ref<Tensor> create(ArrayRef<std::int64_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
// Same as `create`, but returns a raw pointer.
|
||||
static Tensor *createRaw(ArrayRef<std::int64_t> extents,
|
||||
ElementType elementType, void *data);
|
||||
|
||||
ElementType getElementType() const { return elementType; }
|
||||
std::int32_t getRank() const { return rank; }
|
||||
void *getData() const { return data; }
|
||||
template <typename T> T *getData() const { return static_cast<T *>(data); }
|
||||
std::int32_t getExtent(int dimension) const {
|
||||
return getExtents()[dimension];
|
||||
}
|
||||
ArrayRef<std::int32_t> getExtents() const {
|
||||
auto extents = const_cast<Tensor *>(this)->getMutableExtents();
|
||||
return ArrayRef<std::int32_t>(extents.data(), extents.size());
|
||||
}
|
||||
// Returns the number of bytes occupied by the data representing this tensor.
|
||||
// The total allocated amount might be higher to allow e.g. for alignment
|
||||
// nudging.
|
||||
std::int32_t getDataByteSize() const;
|
||||
~Tensor() { std::free(allocatedPtr); }
|
||||
|
||||
private:
|
||||
MutableArrayRef<std::int32_t> getMutableExtents() {
|
||||
auto *tail = reinterpret_cast<std::int32_t *>(this + 1);
|
||||
return MutableArrayRef<std::int32_t>(tail, rank);
|
||||
}
|
||||
|
||||
ElementType elementType;
|
||||
// The number of dimensions of this Tensor.
|
||||
// There are `rank` tail-allocated std::int32_t values representing the
|
||||
// tensor extents.
|
||||
std::int32_t rank;
|
||||
// The buffer base.
|
||||
void *data;
|
||||
// The raw pointer returned by the allocator (currently assumed to be
|
||||
// malloc), suitable for freeing the buffer.
|
||||
void *allocatedPtr;
|
||||
|
||||
// Sizes are tail-allocated.
|
||||
};
|
||||
|
||||
// RtValue is a generic tagged union used to hold all value types
|
||||
// The tag determines the type, and the payload represents the stored
|
||||
// contents of an object. If an object is not trivially destructible,
|
||||
// then it must be refcounted and must have a refCount.
|
||||
#define NPCOMP_FORALL_PRIM_TAGS(_) \
|
||||
_(None) \
|
||||
_(Bool) \
|
||||
_(Int) \
|
||||
_(Float) \
|
||||
_(Double)
|
||||
|
||||
#define NPCOMP_FORALL_REF_TAGS(_) _(Tensor)
|
||||
|
||||
#define NPCOMP_FORALL_TAGS(_) \
|
||||
NPCOMP_FORALL_PRIM_TAGS(_) \
|
||||
NPCOMP_FORALL_REF_TAGS(_)
|
||||
|
||||
struct RtValue final {
|
||||
|
||||
RtValue() : payload{0}, tag(Tag::None) {}
|
||||
|
||||
// Bool
|
||||
RtValue(bool b) : tag(Tag::Bool) { payload.asBool = b; }
|
||||
bool isBool() const { return Tag::Bool == tag; }
|
||||
bool toBool() const {
|
||||
assert(isBool());
|
||||
return payload.asBool;
|
||||
}
|
||||
|
||||
// Int
|
||||
RtValue(std::int64_t i) : tag(Tag::Int) { payload.asInt = i; }
|
||||
RtValue(std::int32_t i) : RtValue(static_cast<int64_t>(i)) {}
|
||||
bool isInt() const { return Tag::Int == tag; }
|
||||
int64_t toInt() const {
|
||||
assert(isInt());
|
||||
return payload.asInt;
|
||||
}
|
||||
|
||||
// Float
|
||||
RtValue(float f) : tag(Tag::Float) { payload.asFloat = f; }
|
||||
bool isFloat() const { return Tag::Float == tag; }
|
||||
float toFloat() const {
|
||||
assert(isFloat());
|
||||
return payload.asFloat;
|
||||
}
|
||||
|
||||
// Double
|
||||
RtValue(double d) : tag(Tag::Double) { payload.asDouble = d; }
|
||||
bool isDouble() const { return Tag::Double == tag; }
|
||||
double toDouble() const {
|
||||
assert(isDouble());
|
||||
return payload.asDouble;
|
||||
}
|
||||
|
||||
// Tensor
|
||||
RtValue(Ref<Tensor> tensor) : tag(Tag::Tensor) {
|
||||
payload.asVoidPtr = reinterpret_cast<void *>(tensor.takePtr());
|
||||
}
|
||||
bool isTensor() const { return Tag::Tensor == tag; }
|
||||
Ref<Tensor> toTensor() const {
|
||||
assert(isTensor());
|
||||
return Ref<Tensor>(reinterpret_cast<Tensor *>(payload.asVoidPtr));
|
||||
}
|
||||
|
||||
// Ref
|
||||
bool isRef() const {
|
||||
#define DEFINE_IS_REF(x) \
|
||||
if (is##x()) { \
|
||||
return true; \
|
||||
}
|
||||
NPCOMP_FORALL_REF_TAGS(DEFINE_IS_REF)
|
||||
#undef DEFINE_IS_REF
|
||||
return false;
|
||||
}
|
||||
|
||||
// Scalar
|
||||
bool isScalar() const {
|
||||
return isBool() || isInt() || isFloat() || isDouble();
|
||||
}
|
||||
|
||||
// RtValue (downcast)
|
||||
const RtValue &toRtValue() const { return *this; }
|
||||
RtValue &toRtValue() { return *this; }
|
||||
|
||||
// Stringify tag for debugging.
|
||||
StringRef tagKind() const {
|
||||
switch (tag) {
|
||||
#define DEFINE_CASE(x) \
|
||||
case Tag::x: \
|
||||
return #x;
|
||||
NPCOMP_FORALL_TAGS(DEFINE_CASE)
|
||||
#undef DEFINE_CASE
|
||||
}
|
||||
// TODO(brycearden): Print tag here
|
||||
return "InvalidTag!";
|
||||
}
|
||||
|
||||
RtValue(const RtValue &rhs) : RtValue(rhs.payload, rhs.tag) {
|
||||
if (isRef()) {
|
||||
#define DEFINE_INCREF(x) \
|
||||
if (is##x()) { \
|
||||
Ref<x>::incref(static_cast<x *>(payload.asVoidPtr)); \
|
||||
return; \
|
||||
}
|
||||
NPCOMP_FORALL_REF_TAGS(DEFINE_INCREF)
|
||||
#undef DEFINE_INCREF
|
||||
assert(false && "Unsupported RtValue type");
|
||||
}
|
||||
}
|
||||
RtValue(RtValue &&rhs) noexcept : RtValue() { swap(rhs); }
|
||||
|
||||
RtValue &operator=(RtValue &&rhs) & noexcept {
|
||||
RtValue(std::move(rhs)).swap(*this); // this also sets rhs to None
|
||||
return *this;
|
||||
}
|
||||
RtValue &operator=(RtValue const &rhs) & {
|
||||
RtValue(rhs).swap(*this);
|
||||
return *this;
|
||||
}
|
||||
|
||||
~RtValue() {
|
||||
if (isRef()) {
|
||||
#define DEFINE_DECREF(x) \
|
||||
if (is##x()) { \
|
||||
Ref<x>::decref(static_cast<x *>(payload.asVoidPtr)); \
|
||||
return; \
|
||||
}
|
||||
NPCOMP_FORALL_REF_TAGS(DEFINE_DECREF)
|
||||
#undef DEFINE_DECREF
|
||||
assert(false && "Unsupported RtValue type");
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void swap(RtValue &rhs) {
|
||||
std::swap(payload, rhs.payload);
|
||||
std::swap(tag, rhs.tag);
|
||||
}
|
||||
|
||||
// NOTE: Runtime tags are intentionally private.
|
||||
// Please use the helper functions above to query information about the type
|
||||
// of a RtValue.
|
||||
enum class Tag : std::uint32_t {
|
||||
#define DEFINE_TAG(x) x,
|
||||
NPCOMP_FORALL_TAGS(DEFINE_TAG)
|
||||
#undef DEFINE_TAG
|
||||
};
|
||||
|
||||
union Payload {
|
||||
bool asBool;
|
||||
int64_t asInt;
|
||||
float asFloat;
|
||||
double asDouble;
|
||||
void *asVoidPtr;
|
||||
};
|
||||
|
||||
RtValue(Payload pl, Tag tag) : payload(pl), tag(tag) {}
|
||||
|
||||
Payload payload;
|
||||
Tag tag;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module loading.
|
||||
// This is the main entry point that users interact with.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
enum class ArgType : std::uint32_t {
|
||||
kNone = 0,
|
||||
kTensor,
|
||||
kF32,
|
||||
kF64,
|
||||
};
|
||||
StringRef getArgTypeAsStringRef(ArgType type);
|
||||
|
||||
// Maximum rank supported across the ABI boundary
|
||||
constexpr static int kMaxRank = 6;
|
||||
|
||||
struct InputArgInfo {
|
||||
// What type of argument this is
|
||||
ArgType argType;
|
||||
// Certain arg types also have an element type
|
||||
ElementType elementType;
|
||||
std::int32_t rank;
|
||||
std::array<std::int32_t, kMaxRank> extents;
|
||||
};
|
||||
|
||||
struct OutputArgInfo {
|
||||
// What type of argument this is
|
||||
ArgType argType;
|
||||
// Certain arg types also have an element type
|
||||
ElementType elementType;
|
||||
std::int32_t rank;
|
||||
std::array<std::int32_t, kMaxRank> extents;
|
||||
// TODO(brycearden): Add checks for whether output buffers alias to input
|
||||
// buffers and populate field(s) here indicating that case
|
||||
};
|
||||
|
||||
// Maximum input or output arity.
|
||||
constexpr static int kMaxArity = 20;
|
||||
|
||||
// Metadata for a particular function.
|
||||
struct FunctionMetadata {
|
||||
std::int32_t numInputs;
|
||||
std::int32_t numOutputs;
|
||||
|
||||
std::array<InputArgInfo, kMaxArity> inputArgInfos;
|
||||
std::array<OutputArgInfo, kMaxArity> outputArgInfos;
|
||||
};
|
||||
|
||||
// Opaque forward declaration of module descriptor type. This is the type
|
||||
// created by the compiler in the module binary.
|
||||
struct ModuleDescriptor;
|
||||
|
||||
// Verifies that the input RtValue arg types match what the user provides
|
||||
// matches the types we expect from the descriptors emitted by the
|
||||
// compiler.
|
||||
//
|
||||
// Returns failure if the input type(s) are not valid
|
||||
LogicalResult checkRtValueArgTypes(const RtValue &value,
|
||||
const InputArgInfo &info);
|
||||
|
||||
// Verifies that the input RtValue shapes matches what the user provides
|
||||
// matches the types we expect from the descriptors emitted by the
|
||||
// compiler.
|
||||
//
|
||||
// Returns failure if the input type(s) are not valid
|
||||
LogicalResult checkRtValueShapes(const RtValue &value,
|
||||
const InputArgInfo &info);
|
||||
|
||||
// Creates an RtValue of the right type from the output metadata
|
||||
// provided by the compiled module
|
||||
RtValue createRtValueFromOutputArgInfo(const OutputArgInfo &info);
|
||||
|
||||
// Low-level invocation API. The number of inputs and outputs should be correct
|
||||
// and match the results of getMetadata.
|
||||
void invoke(ModuleDescriptor *moduleDescriptor, StringRef functionName,
|
||||
ArrayRef<RtValue> inputs, MutableArrayRef<RtValue> outputs);
|
||||
|
||||
// Metadata for function `functionName`.
|
||||
//
|
||||
// Returns failure if functionName wasn't found.
|
||||
LogicalResult getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||
StringRef functionName,
|
||||
FunctionMetadata &outMetadata);
|
||||
|
||||
} // namespace refbackrt
|
||||
|
||||
#endif // NPCOMP_RUNTIME_USERAPI_H
|
|
@ -6,7 +6,6 @@ set(LLVM_LINK_COMPONENTS
|
|||
|
||||
add_npcomp_library(NPCOMPCAPI
|
||||
InitLLVM.cpp
|
||||
RefJITBackend.cpp
|
||||
Registration.cpp
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
|
@ -14,8 +13,6 @@ add_npcomp_library(NPCOMPCAPI
|
|||
MLIRLLVMIR
|
||||
MLIRTargetLLVMIRExport
|
||||
NPCOMPInitAll
|
||||
NPCOMPRefBackendJITHelpers
|
||||
NPCOMPRuntime
|
||||
TorchMLIRTorchDialect
|
||||
TorchMLIRInitAll
|
||||
|
||||
|
|
|
@ -1,127 +0,0 @@
|
|||
//===- RefJITBackend.cpp - CAPI for RefJit --------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp-c/RefJITBackend.h"
|
||||
|
||||
#include "mlir/CAPI/IR.h"
|
||||
#include "mlir/CAPI/Pass.h"
|
||||
#include "mlir/CAPI/Wrap.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
||||
#include "llvm/ADT/Optional.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace refback;
|
||||
using namespace refbackrt;
|
||||
|
||||
using ValueListCpp = SmallVector<RtValue, 4>;
|
||||
DEFINE_C_API_PTR_METHODS(NpcompRefJitModule, JITModule)
|
||||
DEFINE_C_API_PTR_METHODS(NpcompRefJitValueList, ValueListCpp)
|
||||
|
||||
static_assert(static_cast<int>(ElementType::F32) == NPCOMP_REFJIT_F32,
|
||||
"mismatched F32 mapping");
|
||||
|
||||
namespace {
|
||||
template <typename T>
|
||||
static Optional<T> checkError(llvm::Expected<T> &&expected,
|
||||
char **errorMessageCstr, Twine banner = {}) {
|
||||
if (LLVM_LIKELY(expected))
|
||||
return std::move(*expected);
|
||||
|
||||
std::string errorMessage;
|
||||
llvm::raw_string_ostream os(errorMessage);
|
||||
llvm::logAllUnhandledErrors(expected.takeError(), os, banner);
|
||||
os.flush();
|
||||
*errorMessageCstr = strdup(errorMessage.c_str());
|
||||
return llvm::None;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void npcompRefJitBuildBackendCompilationPipeline(MlirPassManager passManager,
|
||||
bool optimize) {
|
||||
JITModule::buildBackendCompilationPipeline(*unwrap(passManager), optimize);
|
||||
}
|
||||
|
||||
NpcompRefJitModule npcompRefJitModuleCreate(MlirModule moduleOp,
|
||||
MlirStringRef *sharedLibs,
|
||||
intptr_t sharedLibsSize,
|
||||
char **errorMessage) {
|
||||
SmallVector<llvm::StringRef> sharedLibsCpp;
|
||||
for (intptr_t i = 0; i < sharedLibsSize; ++i) {
|
||||
sharedLibsCpp.push_back(
|
||||
llvm::StringRef(sharedLibs[i].data, sharedLibs[i].length));
|
||||
}
|
||||
|
||||
auto refJitModuleCpp =
|
||||
checkError(JITModule::fromCompiledModule(unwrap(moduleOp), sharedLibsCpp),
|
||||
errorMessage, "error creating refjit module");
|
||||
if (!refJitModuleCpp)
|
||||
return {nullptr};
|
||||
return wrap(refJitModuleCpp->release());
|
||||
}
|
||||
|
||||
void npcompRefJitModuleDestroy(NpcompRefJitModule module) {
|
||||
delete unwrap(module);
|
||||
}
|
||||
|
||||
bool npcompRefJitModuleInvoke(NpcompRefJitModule m, MlirStringRef functionName,
|
||||
NpcompRefJitValueList inputOutputs,
|
||||
char **errorMessage) {
|
||||
ValueListCpp *ioList = unwrap(inputOutputs);
|
||||
auto results = checkError(
|
||||
unwrap(m)->invoke(llvm::StringRef(functionName.data, functionName.length),
|
||||
*ioList),
|
||||
errorMessage, "error invoking function");
|
||||
ioList->clear();
|
||||
if (!results)
|
||||
return false;
|
||||
|
||||
for (int i = 0, e = results->size(); i < e; ++i) {
|
||||
ioList->push_back(std::move((*results)[i]));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
NpcompRefJitValueList npcompRefJitValueListCreate() {
|
||||
return wrap(new ValueListCpp());
|
||||
}
|
||||
|
||||
void npcompRefJitValueListDestroy(NpcompRefJitValueList list) {
|
||||
delete unwrap(list);
|
||||
}
|
||||
|
||||
intptr_t npcompRefJitValueListSize(NpcompRefJitValueList list) {
|
||||
return unwrap(list)->size();
|
||||
}
|
||||
|
||||
void npcompRefJitValueAddTensorCopy(NpcompRefJitValueList list,
|
||||
NpcompRefJitElementType elementType,
|
||||
const int32_t *extents,
|
||||
intptr_t extentsSize, const void *data) {
|
||||
ElementType elementTypeCpp = static_cast<ElementType>(elementType);
|
||||
auto tensor =
|
||||
Tensor::create(refbackrt::ArrayRef<std::int32_t>(extents, extentsSize),
|
||||
elementTypeCpp, const_cast<void *>(data));
|
||||
unwrap(list)->push_back(std::move(tensor));
|
||||
}
|
||||
|
||||
bool npcompRefJitValueIsaTensor(NpcompRefJitValueList list, intptr_t i) {
|
||||
return (*unwrap(list))[i].isTensor();
|
||||
}
|
||||
|
||||
void *npcompRefJitValueGetTensor(NpcompRefJitValueList list, intptr_t i,
|
||||
NpcompRefJitElementType *elementType,
|
||||
intptr_t *rank, const int32_t **extents) {
|
||||
auto tensor = (*unwrap(list))[i].toTensor();
|
||||
*elementType = static_cast<NpcompRefJitElementType>(tensor->getElementType());
|
||||
*rank = tensor->getRank();
|
||||
*extents = tensor->getExtents().data();
|
||||
return tensor->getData();
|
||||
}
|
|
@ -3,7 +3,6 @@ add_subdirectory(CAPI)
|
|||
add_subdirectory(Conversion)
|
||||
add_subdirectory(Dialect)
|
||||
add_subdirectory(Interfaces)
|
||||
add_subdirectory(RefBackend)
|
||||
|
||||
################################################################################
|
||||
# Setup the initialization target.
|
||||
|
@ -28,11 +27,8 @@ add_npcomp_library(NPCOMPInitAll
|
|||
# Local depends
|
||||
NPCOMPCommonBackend
|
||||
NPCOMPIREEBackend
|
||||
NPCOMPRefBackend
|
||||
NPCOMPRefbackDialect
|
||||
TorchMLIRTorchDialect
|
||||
NPCOMPTorchConversionDialect
|
||||
NPCOMPRefbackrtDialect
|
||||
NPCOMPConversionPasses
|
||||
IREEDialectsIREEDialect
|
||||
|
||||
|
|
|
@ -1,3 +1 @@
|
|||
add_subdirectory(Refback)
|
||||
add_subdirectory(Refbackrt)
|
||||
add_subdirectory(TorchConversion)
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
add_subdirectory(IR)
|
|
@ -1,19 +0,0 @@
|
|||
add_npcomp_dialect_library(NPCOMPRefbackDialect
|
||||
RefbackDialect.cpp
|
||||
RefbackOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Refback
|
||||
|
||||
DEPENDS
|
||||
MLIRRefbackOpsIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
MLIRSideEffectInterfaces
|
||||
MLIRShape
|
||||
)
|
|
@ -1,42 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::refback;
|
||||
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOpsDialect.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RefbackDialect Dialect Interfaces
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
struct RefbackInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
BlockAndValueMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
BlockAndValueMapping &) const final {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void RefbackDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc"
|
||||
>();
|
||||
addInterfaces<RefbackInlinerInterface>();
|
||||
}
|
|
@ -1,19 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::refback;
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc"
|
|
@ -1 +0,0 @@
|
|||
add_subdirectory(IR)
|
|
@ -1,17 +0,0 @@
|
|||
add_npcomp_dialect_library(NPCOMPRefbackrtDialect
|
||||
RefbackrtDialect.cpp
|
||||
RefbackrtOps.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Refbackrt
|
||||
|
||||
DEPENDS
|
||||
MLIRRefbackrtOpsIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRSupport
|
||||
)
|
|
@ -1,25 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::refbackrt;
|
||||
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOpsDialect.cpp.inc"
|
||||
|
||||
void RefbackrtDialect::initialize() {
|
||||
addOperations<
|
||||
#define GET_OP_LIST
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"
|
||||
>();
|
||||
addTypes<TensorType>();
|
||||
}
|
|
@ -1,70 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinOps.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::refbackrt;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleMetadataOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
|
||||
p.printOptionalAttrDictWithKeyword(op->getAttrs());
|
||||
p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
}
|
||||
|
||||
static ParseResult parseModuleMetadataOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
auto *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, llvm::None, llvm::None))
|
||||
return failure();
|
||||
ModuleMetadataOp::ensureTerminator(*body, parser.getBuilder(),
|
||||
result.location);
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FuncMetadataOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(FuncMetadataOp op) {
|
||||
auto *module = op->getParentOp()->getParentOp();
|
||||
auto func = dyn_cast_or_null<FuncOp>(
|
||||
SymbolTable::lookupSymbolIn(module, op.funcName()));
|
||||
if (!func)
|
||||
return op.emitError() << "must reference a valid func";
|
||||
|
||||
if (op.numInputs() != func.getNumArguments())
|
||||
return op.emitError() << "must agree on number of inputs";
|
||||
if (op.numOutputs() != func.getNumResults())
|
||||
return op.emitError() << "must agree on number of outputs";
|
||||
|
||||
if (op.numInputs() > 0) {
|
||||
if (op.numInputs() != op.inputArgTypes()->size()) {
|
||||
return op.emitError() << "number of inputTypes must match number of inputs";
|
||||
}
|
||||
}
|
||||
if (op.numOutputs() > 0) {
|
||||
if (op.numOutputs() != op.outputArgTypes()->size())
|
||||
return op.emitError() << "number of outputTypes must match number of outputs";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"
|
|
@ -13,23 +13,17 @@
|
|||
#include "npcomp/Backend/Common/Passes.h"
|
||||
#include "npcomp/Backend/IREE/Passes.h"
|
||||
#include "npcomp/Conversion/Passes.h"
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionDialect.h"
|
||||
#include "npcomp/Dialect/TorchConversion/Transforms/Passes.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||
// clang-format off
|
||||
registry.insert<refbackrt::RefbackrtDialect,
|
||||
refback::RefbackDialect,
|
||||
mlir::NPCOMP::TorchConversion::TorchConversionDialect,
|
||||
registry.insert<mlir::NPCOMP::TorchConversion::TorchConversionDialect,
|
||||
iree::IREEDialect>();
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
void mlir::NPCOMP::registerAllPasses() {
|
||||
mlir::NPCOMP::registerRefBackendPasses();
|
||||
mlir::NPCOMP::registerConversionPasses();
|
||||
mlir::NPCOMP::registerTorchConversionPasses();
|
||||
mlir::NPCOMP::IREEBackend::registerIREEBackendPasses();
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
add_subdirectory(Runtime)
|
||||
add_subdirectory(JITHelpers)
|
||||
|
||||
add_npcomp_library(NPCOMPRefBackend
|
||||
RefBackend.cpp
|
||||
LowerToLLVM.cpp
|
||||
LowerToRefbackrtABI.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SRC_DIR}/include/npcomp/RefBackend
|
||||
|
||||
DEPENDS
|
||||
NPCOMPRefBackendPassIncGen
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRLinalg
|
||||
MLIRLinalgToLLVM
|
||||
MLIRLinalgTransforms
|
||||
MLIRMathToLLVM
|
||||
MLIRMathTransforms
|
||||
MLIRMemRefToLLVM
|
||||
MLIRSCFToStandard
|
||||
MLIRSCFTransforms
|
||||
MLIRShapeToStandard
|
||||
MLIRStandard
|
||||
MLIRStandardOpsTransforms
|
||||
MLIRStandardToLLVM
|
||||
MLIRTensorTransforms
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(NPCOMPRefBackend)
|
|
@ -1,16 +0,0 @@
|
|||
add_npcomp_library(NPCOMPRefBackendJITHelpers
|
||||
JITModule.cpp
|
||||
|
||||
ADDITIONAL_HEADER_DIRS
|
||||
${PROJECT_SRC_DIR}/include/npcomp/RefBackend/JITHelpers
|
||||
|
||||
LINK_COMPONENTS
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
NPCOMPRuntime
|
||||
NPCOMPRefBackend
|
||||
MLIRExecutionEngine
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(NPCOMPRefBackend)
|
|
@ -1,146 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
||||
#include "mlir/ExecutionEngine/CRunnerUtils.h"
|
||||
#include "mlir/ExecutionEngine/OptUtils.h"
|
||||
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include <sstream>
|
||||
|
||||
using namespace refback;
|
||||
using namespace mlir;
|
||||
using llvm::Error;
|
||||
using llvm::Expected;
|
||||
using llvm::StringError;
|
||||
using llvm::Twine;
|
||||
|
||||
/// Wrap a string into an llvm::StringError.
|
||||
static Error make_string_error(const Twine &message) {
|
||||
return llvm::make_error<StringError>(message.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
JITModule::JITModule() {}
|
||||
|
||||
void JITModule::buildBackendCompilationPipeline(PassManager &pm,
|
||||
bool optimize) {
|
||||
NPCOMP::RefBackendLoweringPipelineOptions options;
|
||||
options.optimize = optimize;
|
||||
NPCOMP::createRefBackendLoweringPipeline(pm, options);
|
||||
}
|
||||
|
||||
llvm::Expected<std::unique_ptr<JITModule>>
|
||||
JITModule::fromCompiledModule(mlir::ModuleOp module,
|
||||
llvm::ArrayRef<llvm::StringRef> sharedLibs) {
|
||||
// Ensure LLVM Dialect -> LLVM IR translations are available.
|
||||
mlir::registerLLVMDialectTranslation(*module->getContext());
|
||||
// Build the JITModule.
|
||||
auto expectedEngine = ExecutionEngine::create(
|
||||
module, /*llvmModuleBuilder=*/nullptr,
|
||||
/*transformer=*/[](llvm::Module *) { return Error::success(); },
|
||||
/*jitCodeGenOptLevel=*/llvm::None, llvm::to_vector<6>(sharedLibs));
|
||||
if (!expectedEngine)
|
||||
return expectedEngine.takeError();
|
||||
std::unique_ptr<JITModule> ret(new JITModule);
|
||||
ret->engine = std::move(*expectedEngine);
|
||||
// Here we abuse mlir::ExecutionEngine a bit. It technically returns a
|
||||
// function pointer, but here we look up a module descriptor.
|
||||
auto expectedAddress = ret->engine->lookup("__npcomp_module_descriptor");
|
||||
if (!expectedAddress)
|
||||
return expectedAddress.takeError();
|
||||
ret->descriptor =
|
||||
reinterpret_cast<refbackrt::ModuleDescriptor *>(*expectedAddress);
|
||||
return std::move(ret);
|
||||
}
|
||||
|
||||
// Converter for bridging to refbackrt llvm-lookalike data structures.
|
||||
static refbackrt::StringRef toRefbackrt(llvm::StringRef s) {
|
||||
return refbackrt::StringRef(s.data(), s.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static refbackrt::ArrayRef<T> toRefbackrt(llvm::ArrayRef<T> a) {
|
||||
return refbackrt::ArrayRef<T>(a.data(), a.size());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
|
||||
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
||||
}
|
||||
|
||||
static std::string stringifyShape(refbackrt::ArrayRef<std::int32_t> extents) {
|
||||
static constexpr char kDynamicDimAsString[] = "?";
|
||||
std::stringstream ss;
|
||||
ss << "(";
|
||||
for (int i = 0, e = extents.size(); i < e; i++) {
|
||||
if (extents[i] < 0)
|
||||
ss << kDynamicDimAsString;
|
||||
else
|
||||
ss << extents[i];
|
||||
if (i != e - 1)
|
||||
ss << "x";
|
||||
}
|
||||
ss << ")";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
// Invokes the compiled function `functionName` with `inputs`.
//
// Validates the inputs against the function's compiled-in metadata (arity,
// argument types, shapes) before crossing the ABI boundary, and returns the
// function's outputs on success or a descriptive error otherwise.
llvm::Expected<llvm::SmallVector<refbackrt::RtValue, 6>>
JITModule::invoke(llvm::StringRef functionName,
                  llvm::ArrayRef<refbackrt::RtValue> inputs) {
  // Look up the function's ABI metadata from the module descriptor.
  refbackrt::FunctionMetadata metadata;
  if (refbackrt::failed(refbackrt::getMetadata(
          descriptor, toRefbackrt(functionName), metadata)))
    return make_string_error("unknown function: " + Twine(functionName));
  SmallVector<refbackrt::RtValue, 6> outputs(metadata.numOutputs);
  // Arity check before touching any per-argument metadata.
  if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
    return make_string_error("invoking '" + Twine(functionName) +
                             "': expected " + Twine(metadata.numInputs) +
                             " inputs");

  // Verify user input types and shapes match what the compiler expects
  for (int i = 0; i < metadata.numInputs; i++) {
    auto &input = inputs[i];
    auto &inputArgInfo = metadata.inputArgInfos[i];
    if (refbackrt::failed(checkRtValueArgTypes(input, inputArgInfo)))
      return make_string_error(
          "invoking '" + Twine(functionName) +
          "': input argument type mismatch. actual (provided by user): " +
          Twine(inputs[i].tagKind().str()) + ", expected (from compiler): " +
          Twine(getArgTypeAsStringRef(inputArgInfo.argType).str()));
    // NOTE(review): checkRtValueShapes presumably only inspects tensor
    // extents; input.toTensor() below assumes the value is a tensor once the
    // type check passed -- confirm against checkRtValueArgTypes.
    if (refbackrt::failed(checkRtValueShapes(input, inputArgInfo)))
      return make_string_error(
          "invoking '" + Twine(functionName) + "': input shape mismatch (%arg" +
          Twine(i) + "). " + "actual (provided by user): " +
          stringifyShape(input.toTensor()->getExtents()) +
          ", expected (from compiler): " +
          stringifyShape(refbackrt::ArrayRef<int32_t>(
              inputArgInfo.extents.data(), inputArgInfo.rank)));
  }

  // Create the correct output RtValue based on FuncMetadata,
  // which contains the arg types (scalar, Tensor, etc.), element types (only
  // applicable if not scalar) and shapes (also only applicable if not scalar)
  //
  // Currently we have to give each RtValue an output type so that we know
  // how to pack / unpack the outputs properly across the ABI boundary in
  // refbackrt::invoke. As a result, we can't just rely on the default
  // construction of each output argument type (otherwise RtValue will have
  // Tag::kNone) currently without passing the ArgInfo structs down to the
  // Runtime level, so we deal with the output type creation here.
  for (int i = 0; i < metadata.numOutputs; i++) {
    outputs[i] =
        refbackrt::createRtValueFromOutputArgInfo(metadata.outputArgInfos[i]);
  }

  // Cross the ABI boundary; `outputs` is filled in-place.
  refbackrt::invoke(
      descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
      toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
  return outputs;
}
|
|
@ -1,747 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/Math/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using mlir::LLVM::LLVMArrayType;
|
||||
using mlir::LLVM::LLVMFuncOp;
|
||||
using mlir::LLVM::LLVMFunctionType;
|
||||
using mlir::LLVM::LLVMPointerType;
|
||||
using mlir::LLVM::LLVMStructType;
|
||||
using mlir::LLVM::LLVMVoidType;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Descriptor types shared with the runtime.
|
||||
//
|
||||
// These correspond to the types in CompilerDataStructures.h
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// MaxRank that the refbackrt ABI lowering is capable of handling
|
||||
// NOTE: This parameter must stay consistent with
|
||||
// `lib/RefBackend/LowerToRefbackrtABI.cpp`
|
||||
static constexpr int kMaxRank = 6;
|
||||
|
||||
// Returns the LLVM dialect `i8*` type.
static LLVMPointerType getInt8PointerType(MLIRContext *context) {
  auto i8Ty = IntegerType::get(context, 8);
  return LLVMPointerType::get(i8Ty);
}
|
||||
|
||||
// Returns the LLVM dialect `i32*` type.
static LLVMPointerType getInt32PointerType(MLIRContext *context) {
  auto i32Ty = IntegerType::get(context, 32);
  return LLVMPointerType::get(i32Ty);
}
|
||||
|
||||
// Get the LLVM struct type mirroring the runtime's per-input descriptor
// (see the corresponding layout in CompilerDataStructures.h). Field order
// and widths must stay in sync with the runtime.
static LLVMStructType getInputDescriptorTy(MLIRContext *context) {
  return LLVMStructType::getLiteral(
      context, {
                   // ArgType
                   IntegerType::get(context, 32),
                   // ElementType
                   IntegerType::get(context, 32),
                   // Rank
                   IntegerType::get(context, 32),
                   // Extents
                   LLVMPointerType::get(IntegerType::get(context, 32)),
                   // IsStatic
                   // IntegerType::get(context, 32),
               });
}
|
||||
|
||||
// Get the LLVM struct type mirroring the runtime's per-output descriptor
// (see the corresponding layout in CompilerDataStructures.h). Currently
// identical in layout to the input descriptor; kept separate so the two
// can diverge independently.
static LLVMStructType getOutputDescriptorTy(MLIRContext *context) {
  return LLVMStructType::getLiteral(
      context, {
                   // ArgType
                   IntegerType::get(context, 32),
                   // ElementType
                   IntegerType::get(context, 32),
                   // Rank
                   IntegerType::get(context, 32),
                   // Extents
                   LLVMPointerType::get(IntegerType::get(context, 32)),
                   // IsStatic
                   // IntegerType::get(context, 32),
               });
}
|
||||
|
||||
// Get the LLVM type for refbackrt::FuncDescriptor.
// Field order and widths must stay in sync with the runtime's definition.
static LLVMStructType getFuncDescriptorTy(MLIRContext *context) {
  return LLVMStructType::getLiteral(
      context, {
                   // Name length.
                   IntegerType::get(context, 32),
                   // Name chars.
                   getInt8PointerType(context),
                   // Type-erased function pointer.
                   getInt8PointerType(context),
                   // Number of inputs.
                   IntegerType::get(context, 32),
                   // Number of outputs.
                   IntegerType::get(context, 32),
                   // Argument descriptors
                   LLVMPointerType::get(getInputDescriptorTy(context)),
                   // Result Descriptors
                   LLVMPointerType::get(getOutputDescriptorTy(context)),
               });
}
|
||||
|
||||
// Get the LLVM type for refbackrt::ModuleDescriptor.
// This is the top-level structure the runtime reads: a count plus a pointer
// to the function descriptor array.
static LLVMStructType getModuleDescriptorTy(MLIRContext *context) {
  return LLVMStructType::getLiteral(
      context, {
                   // std::int32_t numFuncDescriptors;
                   IntegerType::get(context, 32),
                   // FuncDescriptor *functionDescriptors;
                   LLVMPointerType::get(getFuncDescriptorTy(context)),
               });
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compiler runtime functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
// Generic conversion pattern that replaces an op of type `T` with a direct
// call to `backingFunc`, forwarding the op's (already-converted) operands
// unchanged. Useful for refbackrt ops whose lowering is a 1:1 runtime call.
template <typename T>
class TrivialCompilerRuntimeLowering : public OpConversionPattern<T> {
public:
  TrivialCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
      : OpConversionPattern<T>(backingFunc.getContext()),
        backingFunc(backingFunc) {}
  LogicalResult
  matchAndRewrite(T op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc, operands);
    return success();
  }
  // The compiler-runtime function called in place of the matched op.
  LLVM::LLVMFuncOp backingFunc;
};
} // namespace
|
||||
|
||||
// Creates an internal constant global at the start of `module` holding `msg`
// as a nul-terminated i8 array, and returns the new global.
static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
                                         OpBuilder &builder, Location loc) {
  // TODO: Deduplicate strings.
  std::string bytes = msg.getValue().str();
  bytes.push_back('\0');
  auto arrayTy = LLVMArrayType::get(IntegerType::get(module.getContext(), 8),
                                    bytes.size());
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointToStart(module.getBody());

  // To get a unique symbol name, use a suffix derived from the current number
  // of ops in the module.
  // We can't use the SymbolTable's logic for this because the module
  // transiently contains a `func` and `llvm.func` with the same name during
  // conversion, preventing us from instantiating a SymbolTable.
  std::string symbolName =
      (Twine("__npcomp_string_") +
       Twine(llvm::size(llvm::to_vector<6>(module.getOps<LLVM::GlobalOp>()))))
          .str();
  return builder.create<LLVM::GlobalOp>(loc, arrayTy, /*isConstant=*/true,
                                        LLVM::Linkage::Internal, symbolName,
                                        builder.getStringAttr(bytes));
}
|
||||
|
||||
namespace {
// Lowers refbackrt.abort_if to a call to the compiler-runtime abort function,
// passing the predicate plus an i8* to a freshly-created global holding the
// op's message string.
class AbortIfOpCompilerRuntimeLowering
    : public OpConversionPattern<refbackrt::AbortIfOp> {
public:
  AbortIfOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
      : OpConversionPattern<refbackrt::AbortIfOp>(backingFunc.getContext()),
        backingFunc(backingFunc) {}
  LogicalResult
  matchAndRewrite(refbackrt::AbortIfOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    refbackrt::AbortIfOp::Adaptor adaptor(operands);
    auto *context = op.getContext();

    // Create the global string, take its address, and gep to get an `i8*`.
    auto globalOp = createGlobalString(op->getParentOfType<ModuleOp>(),
                                       op.msgAttr(), rewriter, op.getLoc());
    auto msgArray = rewriter.create<LLVM::AddressOfOp>(op.getLoc(), globalOp);
    // Index {0, 0}: first element of the global's char array.
    auto c0 = rewriter.create<LLVM::ConstantOp>(op.getLoc(),
                                                IntegerType::get(context, 32),
                                                rewriter.getI32IntegerAttr(0));
    auto msg =
        rewriter.create<LLVM::GEPOp>(op.getLoc(), getInt8PointerType(context),
                                     msgArray, ValueRange({c0, c0}));
    rewriter.replaceOpWithNewOp<LLVM::CallOp>(
        op, backingFunc, ValueRange({adaptor.pred(), msg}));
    return success();
  }
  // The `abort_if` compiler-runtime function declaration.
  LLVM::LLVMFuncOp backingFunc;
};
} // namespace
|
||||
|
||||
// Create the LLVM runtime function backing the refbackrt op with name `name`
|
||||
// and requiring `type`.
|
||||
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, Type type,
|
||||
OpBuilder &builder,
|
||||
Location loc) {
|
||||
assert(type.isa<LLVMFunctionType>());
|
||||
std::string symbolName = (Twine("__npcomp_compiler_rt_") + name).str();
|
||||
return builder.create<LLVM::LLVMFuncOp>(loc, symbolName, type,
|
||||
LLVM::Linkage::External);
|
||||
}
|
||||
|
||||
// Declares the compiler-runtime entry points in `module` and registers the
// conversion patterns that lower refbackrt ops into calls to them.
static void populateCompilerRuntimePatterns(ModuleOp module,
                                            RewritePatternSet &patterns,
                                            LLVMTypeConverter &typeConverter) {
  MLIRContext *context = module.getContext();
  OpBuilder builder(module.getBodyRegion());

  // void __npcomp_compiler_rt_abort_if(i1 pred, i8 *msg)
  auto abortIfFuncTy = LLVMFunctionType::get(
      LLVMVoidType::get(context),
      {IntegerType::get(context, 1), getInt8PointerType(context)},
      /*isVarArg=*/false);
  LLVMFuncOp abortIfFunc = createCompilerRuntimeFuncDecl(
      "abort_if", abortIfFuncTy, builder, module.getLoc());
  patterns.add<AbortIfOpCompilerRuntimeLowering>(abortIfFunc);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lowering for module metadata
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Emits the global constant data the runtime walks to find and call compiled
// functions: for each func metadata op, a name string, per-input and
// per-output descriptor arrays, and flattened shape arrays; then a single
// `__npcomp_func_descriptors` array tying them together. Returns that array's
// global. Layouts must match getFuncDescriptorTy / get{Input,Output}DescriptorTy.
static LLVM::GlobalOp
createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
                          OpBuilder &builder, Location loc) {
  auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);

  // Per-function globals created in the first loop, consumed when building
  // the descriptor arrays below. Keyed by function name.
  DenseMap<StringRef, LLVM::GlobalOp> globalsByName;
  DenseMap<StringRef, LLVM::GlobalOp> inputDescriptorsByName;
  DenseMap<StringRef, LLVM::GlobalOp> outputDescriptorsByName;
  DenseMap<StringRef, LLVM::GlobalOp> inputShapesByName;
  DenseMap<StringRef, LLVM::GlobalOp> outputShapesByName;
  for (auto funcMetadata : funcMetadatas) {
    // Function name as a (non-nul-terminated) i8 array constant.
    auto arrayTy = LLVMArrayType::get(IntegerType::get(builder.getContext(), 8),
                                      funcMetadata.funcName().size());
    std::string llvmSymbolName =
        (Twine("__npcomp_internal_constant_") + funcMetadata.funcName()).str();
    auto global = builder.create<LLVM::GlobalOp>(
        loc, arrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
        llvmSymbolName, builder.getStringAttr(funcMetadata.funcName()));
    globalsByName[funcMetadata.funcName()] = global;

    // Create constants for the input / output shapes
    if (funcMetadata.inputShapes().hasValue()) {
      auto i32ArrayInputSymbolName =
          (Twine("__npcomp_internal_constant_input_shapes_") +
           funcMetadata.funcName())
              .str();
      auto inputNumElements = funcMetadata.inputShapes()->getNumElements();
      auto inputI32ArrayTy =
          LLVMArrayType::get(builder.getIntegerType(32), inputNumElements);
      auto inputShapesGlobal = builder.create<LLVM::GlobalOp>(
          loc, inputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
          i32ArrayInputSymbolName,
          /*value=*/funcMetadata.inputShapes().getValue());

      inputShapesByName[funcMetadata.funcName()] = inputShapesGlobal;
    }

    if (funcMetadata.outputShapes().hasValue()) {
      auto i32ArrayOutputSymbolName =
          (Twine("__npcomp_internal_constant_output_shapes_") +
           funcMetadata.funcName())
              .str();
      auto outputNumElements = funcMetadata.outputShapes()->getNumElements();
      auto outputI32ArrayTy =
          LLVMArrayType::get(builder.getIntegerType(32), outputNumElements);
      auto outputShapesGlobal = builder.create<LLVM::GlobalOp>(
          loc, outputI32ArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
          i32ArrayOutputSymbolName,
          /*value=*/funcMetadata.outputShapes().getValue());

      outputShapesByName[funcMetadata.funcName()] = outputShapesGlobal;
    }
  }

  // Helper: insert `value` into the struct/array aggregate `descriptor` at
  // `position`, rebinding `descriptor` to the new SSA value.
  auto updateDescriptor = [&](Value &descriptor, Value value,
                              std::initializer_list<int32_t> position) {
    descriptor = builder.create<LLVM::InsertValueOp>(
        loc, descriptor, value,
        /*position=*/builder.getI32ArrayAttr(position));
  };
  // Helper: same, but materializes an i32 constant from `attr` first.
  auto updateDescriptorWithI32Attr =
      [&](Value &descriptor, Attribute attr,
          std::initializer_list<int32_t> position) {
        auto constant = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty, attr);
        updateDescriptor(descriptor, constant, position);
      };

  // Create global input descriptors
  for (auto funcMetadata : funcMetadatas) {
    std::string llvmInputSymbolName =
        (Twine("__npcomp_input_descriptors_") + funcMetadata.funcName()).str();
    auto inputDescriptorTy = getInputDescriptorTy(builder.getContext());
    auto inputDescriptorArrayTy =
        LLVMArrayType::get(inputDescriptorTy, funcMetadata.numInputs());
    auto inputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
        loc, inputDescriptorArrayTy, /*isConstant=*/true,
        LLVM::Linkage::Internal, llvmInputSymbolName, /*value=*/Attribute());

    // Build the global's initializer region.
    OpBuilder::InsertionGuard guard(builder);
    builder.createBlock(&inputDescriptorArrayGlobal.initializer());

    auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
                                               builder.getI32IntegerAttr(0));

    Value inputDescriptorArray =
        builder.create<LLVM::UndefOp>(loc, inputDescriptorArrayTy);

    for (int i = 0, e = funcMetadata.numInputs(); i < e; i++) {
      // Arg Type
      // NOTE(review): emitError does not abort this function; the getValue()
      // calls below would still run on missing attributes -- confirm this is
      // guarded by op verification upstream.
      if (!funcMetadata.inputArgTypes().hasValue())
        funcMetadata.emitError()
            << "numInputs > 0 but there are no inputArgTypes?";
      updateDescriptorWithI32Attr(inputDescriptorArray,
                                  funcMetadata.inputArgTypes()->getValue(i),
                                  {i, 0});
      // Element Type
      updateDescriptorWithI32Attr(inputDescriptorArray,
                                  funcMetadata.inputElementTypes()->getValue(i),
                                  {i, 1});

      // Rank
      // auto inputShapesType =
      //     funcMetadata.inputShapes()->getType().dyn_cast<ShapedType>();
      auto rank = funcMetadata.inputRanks()->getValue(i);
      updateDescriptorWithI32Attr(inputDescriptorArray, rank, {i, 2});

      // Shape
      // Each shape array is derived by offseting of kMaxRank * arg index
      auto extentsArray = builder.create<LLVM::AddressOfOp>(
          loc, inputShapesByName[funcMetadata.funcName()]);
      auto cShapeOffset = builder.create<LLVM::ConstantOp>(
          loc, IntegerType::get(builder.getContext(), 32),
          builder.getI32IntegerAttr(i * kMaxRank));
      auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
          loc, getInt32PointerType(builder.getContext()), extentsArray,
          ValueRange({c0, cShapeOffset}));
      updateDescriptor(inputDescriptorArray, extentsArrayPtr, {i, 3});
    }

    builder.create<LLVM::ReturnOp>(loc, inputDescriptorArray);

    inputDescriptorsByName[funcMetadata.funcName()] =
        std::move(inputDescriptorArrayGlobal);
  }

  // Create global output descriptors
  for (auto funcMetadata : funcMetadatas) {
    std::string llvmOutputSymbolName =
        (Twine("__npcomp_output_descriptors_") + funcMetadata.funcName()).str();
    auto outputDescriptorTy = getOutputDescriptorTy(builder.getContext());
    auto outputDescriptorArrayTy =
        LLVMArrayType::get(outputDescriptorTy, funcMetadata.numOutputs());
    auto outputDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
        loc, outputDescriptorArrayTy, /*isConstant=*/true,
        LLVM::Linkage::Internal, llvmOutputSymbolName, /*value=*/Attribute());

    // Build the global's initializer region.
    OpBuilder::InsertionGuard guard(builder);
    builder.createBlock(&outputDescriptorArrayGlobal.initializer());

    auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
                                               builder.getI32IntegerAttr(0));

    Value outputDescriptorArray =
        builder.create<LLVM::UndefOp>(loc, outputDescriptorArrayTy);

    for (int i = 0, e = funcMetadata.numOutputs(); i < e; i++) {
      // NOTE(review): as with inputs above, emitError does not abort.
      if (!funcMetadata.outputArgTypes().hasValue())
        funcMetadata.emitError()
            << "numOutputs > 0 but there are no outputArgTypes?";
      // Arg Type
      updateDescriptorWithI32Attr(outputDescriptorArray,
                                  funcMetadata.outputArgTypes()->getValue(i),
                                  {i, 0});
      // Element Type
      updateDescriptorWithI32Attr(
          outputDescriptorArray, funcMetadata.outputElementTypes()->getValue(i),
          {i, 1});

      // Rank
      // auto outputShapesType =
      //     funcMetadata.outputShapes()->getType().dyn_cast<ShapedType>();
      auto rank = funcMetadata.outputRanks()->getValue(i);
      updateDescriptorWithI32Attr(outputDescriptorArray, rank, {i, 2});

      // Shapes
      // Offset by kMaxRank * arg index
      auto extentsArray = builder.create<LLVM::AddressOfOp>(
          loc, outputShapesByName[funcMetadata.funcName()]);
      auto cShapeOffset = builder.create<LLVM::ConstantOp>(
          loc, IntegerType::get(builder.getContext(), 32),
          builder.getI32IntegerAttr(i * kMaxRank));
      auto extentsArrayPtr = builder.create<LLVM::GEPOp>(
          loc, getInt32PointerType(builder.getContext()), extentsArray,
          ValueRange({c0, cShapeOffset}));
      updateDescriptor(outputDescriptorArray, extentsArrayPtr, {i, 3});
    }

    builder.create<LLVM::ReturnOp>(loc, outputDescriptorArray);

    outputDescriptorsByName[funcMetadata.funcName()] =
        outputDescriptorArrayGlobal;
  }

  // This must match FuncDescriptor in the runtime.
  auto funcDescriptorTy = getFuncDescriptorTy(builder.getContext());
  auto funcDescriptorArrayTy =
      LLVMArrayType::get(funcDescriptorTy, funcMetadatas.size());
  auto funcDescriptorArrayGlobal = builder.create<LLVM::GlobalOp>(
      loc, funcDescriptorArrayTy, /*isConstant=*/true, LLVM::Linkage::Internal,
      "__npcomp_func_descriptors",
      /*value=*/Attribute());

  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(&funcDescriptorArrayGlobal.initializer());

  auto c0 = builder.create<LLVM::ConstantOp>(loc, llvmI32Ty,
                                             builder.getI32IntegerAttr(0));
  // Build the initializer.
  Value funcDescriptorArray =
      builder.create<LLVM::UndefOp>(loc, funcDescriptorArrayTy);

  for (auto funcMetadataAndIndex : llvm::enumerate(funcMetadatas)) {
    auto funcMetadata = funcMetadataAndIndex.value();
    int32_t index = funcMetadataAndIndex.index();

    // Name length.
    updateDescriptorWithI32Attr(
        funcDescriptorArray,
        builder.getI32IntegerAttr(funcMetadata.funcName().size()), {index, 0});

    // Name chars.
    auto funcNameArray = builder.create<LLVM::AddressOfOp>(
        loc, globalsByName[funcMetadata.funcName()]);
    auto funcNamePtr = builder.create<LLVM::GEPOp>(
        loc, getInt8PointerType(builder.getContext()), funcNameArray,
        ValueRange({c0, c0}));
    updateDescriptor(funcDescriptorArray, funcNamePtr, {index, 1});

    // Function pointer.
    //
    // We create this reference to the original function (and use a dummy i8*
    // type). We will fix this up after conversion to point at wrapper
    // functions that satisfy the ABI requirements.
    // The bitcast is required so that after conversion the inserted value is an
    // i8* as expected by the descriptor struct.
    auto funcAddress = builder.create<LLVM::AddressOfOp>(
        loc, getInt8PointerType(builder.getContext()), funcMetadata.funcName());
    auto typeErasedFuncAddress = builder.create<LLVM::BitcastOp>(
        loc, getInt8PointerType(builder.getContext()), funcAddress);
    updateDescriptor(funcDescriptorArray, typeErasedFuncAddress, {index, 2});

    // Number of inputs.
    updateDescriptorWithI32Attr(funcDescriptorArray,
                                funcMetadata.numInputsAttr(), {index, 3});

    // Number of outputs.
    updateDescriptorWithI32Attr(funcDescriptorArray,
                                funcMetadata.numOutputsAttr(), {index, 4});

    // Input descriptors
    auto inputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
        loc, inputDescriptorsByName[funcMetadata.funcName()]);
    auto rawInputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
        loc, LLVMPointerType::get(getInputDescriptorTy(builder.getContext())),
        inputDescriptorsArrayAddress);
    updateDescriptor(funcDescriptorArray, rawInputDescriptorsPtr, {index, 5});

    // Output descriptors
    auto outputDescriptorsArrayAddress = builder.create<LLVM::AddressOfOp>(
        loc, outputDescriptorsByName[funcMetadata.funcName()]);
    auto rawOutputDescriptorsPtr = builder.create<LLVM::BitcastOp>(
        loc, LLVMPointerType::get(getOutputDescriptorTy(builder.getContext())),
        outputDescriptorsArrayAddress);
    updateDescriptor(funcDescriptorArray, rawOutputDescriptorsPtr, {index, 6});
  }

  builder.create<LLVM::ReturnOp>(loc, funcDescriptorArray);

  return funcDescriptorArrayGlobal;
}
|
||||
|
||||
// Emits the single externally-visible module descriptor global
// (`_mlir___npcomp_module_descriptor`) holding the function-descriptor count
// and a pointer to `funcDescriptorArray`. Returns the new global.
LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
                                      OpBuilder &builder, Location loc) {
  auto llvmI32Ty = IntegerType::get(builder.getContext(), 32);
  auto moduleDescriptorTy = getModuleDescriptorTy(builder.getContext());
  // TODO: Ideally this symbol name would somehow be related to the module
  // name, if we could consistently assume we had one.
  // TODO: We prepend _mlir so that mlir::ExecutionEngine's lookup logic (which
  // is typically only meant for function pointers) will find this raw symbol.
  auto moduleDescriptorGlobal = builder.create<LLVM::GlobalOp>(
      loc, moduleDescriptorTy, /*isConstant=*/true, LLVM::Linkage::External,
      "_mlir___npcomp_module_descriptor",
      /*value=*/Attribute());
  // Build the global's initializer region.
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(&moduleDescriptorGlobal.initializer());

  Value moduleDescriptor =
      builder.create<LLVM::UndefOp>(loc, moduleDescriptorTy);

  // Helper: insert `value` at `position`, rebinding moduleDescriptor.
  auto updateDescriptor = [&](Value value,
                              std::initializer_list<int32_t> position) {
    moduleDescriptor = builder.create<LLVM::InsertValueOp>(
        loc, moduleDescriptor, value,
        /*position=*/builder.getI32ArrayAttr(position));
  };

  // Field 0: number of function descriptors.
  updateDescriptor(builder.create<LLVM::ConstantOp>(
                       loc, llvmI32Ty,
                       builder.getI32IntegerAttr(funcDescriptorArray.getType()
                                                     .cast<LLVMArrayType>()
                                                     .getNumElements())),
                   {0});

  // Field 1: pointer to the (array of) function descriptors, decayed to
  // a FuncDescriptor*.
  auto funcDecriptorArrayAddress =
      builder.create<LLVM::AddressOfOp>(loc, funcDescriptorArray);
  auto rawFuncDescriptorPtr = builder.create<LLVM::BitcastOp>(
      loc, LLVMPointerType::get(getFuncDescriptorTy(builder.getContext())),
      funcDecriptorArrayAddress);
  updateDescriptor(rawFuncDescriptorPtr, {1});
  builder.create<LLVM::ReturnOp>(loc, moduleDescriptor);

  return moduleDescriptorGlobal;
}
|
||||
|
||||
namespace {
// Lowers refbackrt.module_metadata (and the func_metadata ops nested in it)
// into the global descriptor structures the runtime consumes, then erases
// the metadata op.
class LowerModuleMetadata
    : public OpConversionPattern<refbackrt::ModuleMetadataOp> {
public:
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(refbackrt::ModuleMetadataOp op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto funcMetadatas =
        llvm::to_vector<6>(op.metadatas().getOps<refbackrt::FuncMetadataOp>());
    auto funcDescriptorArray =
        createFuncDescriptorArray(funcMetadatas, rewriter, op.getLoc());
    auto moduleDescriptor =
        createModuleDescriptor(funcDescriptorArray, rewriter, op.getLoc());

    // TODO: create get module descriptor wrapper (or upgrade
    // mlir::ExecutionEngine to allow raw symbol lookup)
    (void)moduleDescriptor;

    rewriter.eraseOp(op);
    return success();
  }
};
} // namespace
|
||||
|
||||
// Performs the calculation:
|
||||
// ```
|
||||
// ty *f(void **voidStarStar, int32_t i) {
|
||||
// return reinterpret_cast<ty *>(voidStarStar[i]);
|
||||
// }
|
||||
// ```
|
||||
static Value getTypedAddressFromVoidStarStar(Value voidStarStar, int32_t index,
|
||||
Type ty, OpBuilder &builder,
|
||||
Location loc) {
|
||||
Value ci = builder.create<LLVM::ConstantOp>(
|
||||
loc, IntegerType::get(builder.getContext(), 32),
|
||||
builder.getI32IntegerAttr(index));
|
||||
|
||||
// Do `voidStarStar[i]` as a gep + load.
|
||||
auto inputPtrAddr = builder.create<LLVM::GEPOp>(
|
||||
loc, LLVMPointerType::get(getInt8PointerType(builder.getContext())),
|
||||
voidStarStar, ValueRange(ci));
|
||||
auto inputPtr = builder.create<LLVM::LoadOp>(loc, inputPtrAddr);
|
||||
return builder.create<LLVM::BitcastOp>(loc, LLVMPointerType::get(ty),
|
||||
inputPtr);
|
||||
}
|
||||
|
||||
// Loads each argument of `funcTy` out of the void** array `inputsPtrPtr`:
// each slot is bitcast to the corresponding parameter's pointer type and
// dereferenced.
static SmallVector<Value, 6> loadCallArgs(Value inputsPtrPtr,
                                          LLVMFunctionType funcTy,
                                          OpBuilder &builder, Location loc) {
  SmallVector<Value, 6> callArgs;
  // For each void* in the void**, cast it to the right type and load it.
  for (int argIdx = 0, numArgs = funcTy.getNumParams(); argIdx != numArgs;
       ++argIdx) {
    auto typedAddr = getTypedAddressFromVoidStarStar(
        inputsPtrPtr, argIdx, funcTy.getParamType(argIdx), builder, loc);
    callArgs.push_back(builder.create<LLVM::LoadOp>(loc, typedAddr));
  }
  return callArgs;
}
|
||||
|
||||
// Returns the LLVM type that the Std-to-LLVM conversion uses for an unranked
// memref on ABI boundaries.
static Type getUnrankedMemrefDescriptorType(MLIRContext *context) {
  LLVMTypeConverter converter(context);
  // LLVMTypeConverter doesn't directly expose the struct type used to represent
  // unranked memrefs on ABI boundaries. To get that type, we convert
  // an unranked memref type and see what it produces.
  //
  // An unranked memref is just a size_t for the rank and an void* pointer to
  // descriptor, so the choice of element type here is arbitrary -- it all
  // converts to the same thing.
  return converter.convertType(
      UnrankedMemRefType::get(Float32Type::get(context),
                              /*memorySpace=*/0));
}
|
||||
|
||||
// Returns the LLVM type that an f32 converts to under the standard lowering.
static Type getFloatType(MLIRContext *context) {
  LLVMTypeConverter converter(context);
  auto f32Ty = FloatType::getF32(context);
  return converter.convertType(f32Ty);
}
|
||||
|
||||
// Writes out the logical results of the wrapper function through the void**
// passed on the ABI boundary. Because LLVM (and hence llvm.func)
// only supports a single return type (or void/no results), the logic here needs
// to be aware of the convention used in the Std to LLVM conversion to map
// multiple return types. The details of this are in the function
// packFunctionResults and its callers:
// https://github.com/llvm/llvm-project/blob/fad9cba8f58ba9979f390a49cf174ec9fcec29a6/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp#L282
static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
                                OpBuilder &builder, Location loc) {
  // 0 results. Nothing to do.
  if (callToWrapped.getNumResults() == 0)
    return;
  Value result = callToWrapped.getResult(0);
  auto ty = result.getType();

  // 1 logical result.
  // NOTE(review): only unranked-memref and f32 single results are recognized
  // here; other scalar result types would fall through to the struct assert
  // below -- confirm the supported result-type set upstream.
  if (ty == getUnrankedMemrefDescriptorType(ty.getContext())) {
    Value addr =
        getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
    builder.create<LLVM::StoreOp>(loc, result, addr);
    return;
  } else if (ty == getFloatType(ty.getContext())) {
    Value addr =
        getTypedAddressFromVoidStarStar(resultsPtrPtr, 0, ty, builder, loc);
    builder.create<LLVM::StoreOp>(loc, result, addr);
    return;
  }
  assert(ty.isa<LLVMStructType>() && "must be a multi-result packed struct!");
  auto structType = ty.cast<LLVMStructType>();
  // >=2 logical results. The convention linked above will create a struct
  // wrapping.
  for (int i = 0, e = structType.getBody().size(); i < e; i++) {
    auto elementTy = structType.getBody()[i];
    // Slot i of the void** receives struct element i.
    Value addr = getTypedAddressFromVoidStarStar(resultsPtrPtr, i, elementTy,
                                                 builder, loc);
    int32_t i32I = i;
    Value value = builder.create<LLVM::ExtractValueOp>(
        loc, elementTy, result, builder.getI32ArrayAttr({i32I}));
    builder.create<LLVM::StoreOp>(loc, value, addr);
  }
}
|
||||
|
||||
// Construct a wrapper function.
// For an externally visible function f(T1, T2) -> T3, T4, we create a
// wrapper
// __refbackrt_wrapper_f(void **inputs, void ** outputs) {
//   T3 t3;
//   T4 t4;
//   (t3, t4) = f(*cast<T1*>(inputs[0]), *cast<T2*>(inputs[1]));
//   *cast<T3*>(outputs[0]) = t3;
//   *cast<T4*>(outputs[1]) = t4;
// }
// This is very similar to MLIR's "packed" convention, but supporting
// outputs.
// TODO: Extend MLIR's void** wrappers to have outputs in this way.
static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
  auto *context = func.getContext();
  LLVMFunctionType funcTy = func.getType();
  // Wrapper signature: void(i8**, i8**) -- fixed regardless of `func`'s type.
  auto voidStarTy = getInt8PointerType(context);
  auto voidStarStarTy = LLVMPointerType::get(voidStarTy);
  auto wrapperTy = LLVMFunctionType::get(LLVMVoidType::get(context),
                                         {voidStarStarTy, voidStarStarTy},
                                         /*isVarArg=*/false);
  constexpr char kRefbackrtWrapperPrefix[] = "__refbackrt_wrapper_";
  auto wrapperName = (Twine(kRefbackrtWrapperPrefix) + func.getName()).str();
  OpBuilder moduleBuilder(func->getParentRegion());
  LLVMFuncOp wrapper = moduleBuilder.create<LLVMFuncOp>(
      func.getLoc(), wrapperName, wrapperTy, LLVM::Linkage::External);

  // Create the function body: unpack inputs, call the wrapped function,
  // store results through the outputs pointer, return void.
  Block &body = *wrapper.addEntryBlock();
  auto builder = OpBuilder::atBlockBegin(&body);
  auto callArgs =
      loadCallArgs(body.getArgument(0), funcTy, builder, func.getLoc());
  auto call = builder.create<LLVM::CallOp>(func.getLoc(), func, callArgs);
  storeWrapperResults(call, body.getArgument(1), builder, func.getLoc());
  builder.create<LLVM::ReturnOp>(func.getLoc(), ValueRange());
  return wrapper;
}
|
||||
|
||||
namespace {
|
||||
class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<LLVM::LLVMDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
||||
LLVMTypeConverter converter(context);
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
LLVMConversionTarget target(*context);
|
||||
populateCompilerRuntimePatterns(module, patterns, converter);
|
||||
target.addLegalOp<ModuleOp>();
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
populateMathToLLVMConversionPatterns(converter, patterns);
|
||||
populateMemRefToLLVMConversionPatterns(converter, patterns);
|
||||
patterns.add<LowerModuleMetadata>(context);
|
||||
|
||||
// TODO: Move these "std to std" legalizations to their own pass if we grow
|
||||
// lots of these patterns.
|
||||
populateExpandTanhPattern(patterns);
|
||||
populateLinalgToLLVMConversionPatterns(converter, patterns);
|
||||
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
// Rewrite llvm.mlir.addressof ops that reference the original exported
|
||||
// functions from the module to instead refer to wrapper functions.
|
||||
// These wrapper functions have a fixed ABI
|
||||
// (`void f(void **inputs, void **results)`) which we can interface to from
|
||||
// external code without dealing with platform-dependent
|
||||
// register-level calling conventions. We embed enough information in the
|
||||
// module metadata to make sure that calling code can e.g. preallocate
|
||||
// enough outputs and with the right types to safely funnel through this
|
||||
// convention.
|
||||
module.walk([&](LLVM::AddressOfOp op) {
|
||||
auto originalFunc =
|
||||
module.lookupSymbol<LLVM::LLVMFuncOp>(op.global_name());
|
||||
if (!originalFunc)
|
||||
return;
|
||||
auto wrapper = createWrapperFunc(originalFunc);
|
||||
op.getResult().setType(LLVMPointerType::get(wrapper.getType()));
|
||||
Builder builder(op.getContext());
|
||||
op->setAttr("global_name",
|
||||
SymbolRefAttr::get(builder.getContext(), wrapper.getName()));
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> mlir::NPCOMP::createLowerToLLVMPass() {
|
||||
return std::make_unique<LowerToLLVM>();
|
||||
}
|
|
@ -1,446 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "PassDetail.h"
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Verifier.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
// Since input/output shapes are not hyper-rectangular we specify
|
||||
// a maximum rank for each input shape such that shapes are padded
|
||||
// out to kMaxRank at the ABI boundary. That way we can represent
|
||||
// shapes using a traditional DenseElementsAttr.
|
||||
//
|
||||
// NOTE: When changing this parameter, also change the same `kMaxRank`
|
||||
// parameter in `lib/RefBackend/LowerToLLVM.cpp` so that the LLVM lowering
|
||||
// stays consistent.
|
||||
static constexpr int kMaxRank = 6;
|
||||
|
||||
// Get the type used to represent MemRefType `type` on ABI boundaries.
|
||||
// For convenience we do a cast to MemRefType internally.
|
||||
static Type getABIMemrefType(Type type) {
|
||||
return UnrankedMemRefType::get(type.cast<MemRefType>().getElementType(),
|
||||
/*memorySpace=*/0);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Creating module metadata.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Returns true if the function signature can be expressed with the refbackrt
|
||||
// ABI.
|
||||
static bool expressibleWithRefbackrtABI(FunctionType type) {
|
||||
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
|
||||
return llvm::all_of(
|
||||
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
||||
[](Type t) {
|
||||
return t.isa<UnrankedMemRefType, MemRefType, FloatType>();
|
||||
});
|
||||
}
|
||||
|
||||
// Returns the integer rerpresentation of the CompilerDataStructures::ABIType
|
||||
// Must stay aligned with CompilerDataStructures::ABIArgType enum
|
||||
static uint32_t getIntReprForABIType(Type type) {
|
||||
if (type.isa<MemRefType>() || type.isa<UnrankedMemRefType>()) {
|
||||
return 1;
|
||||
} else if (auto floatTy = type.dyn_cast<FloatType>()) {
|
||||
switch (floatTy.getWidth()) {
|
||||
case 32:
|
||||
return 2;
|
||||
case 64:
|
||||
return 3;
|
||||
default:
|
||||
assert(false && "Unsupported float bit width");
|
||||
}
|
||||
} else if (auto intTy = type.dyn_cast<IntegerType>()) {
|
||||
}
|
||||
// assert(false && "couldn't get IntReprForABIType");
|
||||
return -1;
|
||||
}
|
||||
|
||||
// Must stay aligned with CompilerDataStructures::ABIElementType enum
|
||||
static uint32_t getIntReprForABIElementType(Type type) {
|
||||
if (auto shapedTy = type.dyn_cast<ShapedType>()) {
|
||||
auto elemTy = shapedTy.getElementType();
|
||||
if (auto floatTy = elemTy.dyn_cast<FloatType>()) {
|
||||
switch (floatTy.getWidth()) {
|
||||
case 32:
|
||||
return 1;
|
||||
default:
|
||||
assert(false && "Unsupported tensor element type");
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
static SmallVector<int32_t, kMaxRank>
|
||||
getExtentsForType(Type type, const int32_t maxRank = kMaxRank) {
|
||||
// Extend all shapes out to 4D to make our lives easier at the ABI boundary
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
if (!shapedType.hasRank()) {
|
||||
return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
|
||||
}
|
||||
|
||||
auto shape = shapedType.getShape();
|
||||
auto shapeRank = shapedType.getRank();
|
||||
if (shapeRank <= maxRank) {
|
||||
SmallVector<int32_t, kMaxRank> extendedShape;
|
||||
// Push back all the values of the shape
|
||||
for (auto extentAndIndex : llvm::enumerate(shape)) {
|
||||
auto extent = extentAndIndex.value();
|
||||
auto index = extentAndIndex.index();
|
||||
if (shapedType.isDynamic(index)) {
|
||||
extendedShape.push_back(-1);
|
||||
} else {
|
||||
extendedShape.push_back(extent);
|
||||
}
|
||||
}
|
||||
|
||||
// Pad whatever is left so we have even vectors
|
||||
auto padRank = maxRank - shapeRank;
|
||||
for (int i = 0; i < padRank; i++)
|
||||
extendedShape.push_back(0xDEAD'BEEF);
|
||||
|
||||
return extendedShape;
|
||||
} else {
|
||||
assert(false && "unsupported rank");
|
||||
}
|
||||
}
|
||||
|
||||
// Represent Scalar's as all 1's.
|
||||
return {kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank, kMaxRank};
|
||||
}
|
||||
|
||||
int32_t getRankForType(Type type) {
|
||||
// Returns a rank of -1 if the tensor is unranked
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
return shapedType.hasRank() ? shapedType.getRank() : -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t hasStaticShape(Type type) {
|
||||
if (auto shapedType = type.dyn_cast<ShapedType>()) {
|
||||
return shapedType.hasStaticShape() ? 1 : 0;
|
||||
}
|
||||
// Assume scalars and non-shaped type things are static
|
||||
return 1;
|
||||
}
|
||||
|
||||
static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||
auto moduleMetadata =
|
||||
OpBuilder::atBlockBegin(module.getBody())
|
||||
.create<refbackrt::ModuleMetadataOp>(module.getLoc());
|
||||
moduleMetadata.metadatas().push_back(new Block);
|
||||
Block &metadatas = moduleMetadata.metadatas().front();
|
||||
OpBuilder::atBlockEnd(&metadatas)
|
||||
.create<refbackrt::ModuleMetadataTerminatorOp>(module.getLoc());
|
||||
|
||||
SymbolTable symbolTable(module);
|
||||
auto builder = OpBuilder::atBlockBegin(&metadatas);
|
||||
for (auto func : module.getOps<FuncOp>()) {
|
||||
if (symbolTable.getSymbolVisibility(func) !=
|
||||
SymbolTable::Visibility::Public) {
|
||||
continue;
|
||||
}
|
||||
// TODO: Add richer information here such as expected shapes and element
|
||||
// types.
|
||||
SmallVector<uint32_t, 6> inputABIArgTypes;
|
||||
SmallVector<uint32_t, 6> inputABIElementTypes;
|
||||
SmallVector<SmallVector<int32_t, kMaxRank>, 6> inputABIShapes;
|
||||
SmallVector<uint32_t, 6> inputABIRanks;
|
||||
// SmallVector<uint32_t, 6> inputIsStatic;
|
||||
for (const auto &inputArgType : func.getBody().front().getArgumentTypes()) {
|
||||
inputABIArgTypes.push_back(getIntReprForABIType(inputArgType));
|
||||
inputABIElementTypes.push_back(getIntReprForABIElementType(inputArgType));
|
||||
inputABIShapes.push_back(
|
||||
getExtentsForType(inputArgType, /*maxRank=*/kMaxRank));
|
||||
inputABIRanks.push_back(getRankForType(inputArgType));
|
||||
// inputIsStatic.push_back(hasStaticShape(inputArgType));
|
||||
}
|
||||
|
||||
SmallVector<uint32_t, 6> outputABIArgTypes;
|
||||
SmallVector<uint32_t, 6> outputABIElementTypes;
|
||||
SmallVector<SmallVector<int32_t, kMaxRank>, 6> outputABIShapes;
|
||||
SmallVector<uint32_t, 6> outputABIRanks;
|
||||
SmallVector<uint32_t, 6> outputIsStatic;
|
||||
for (const auto &outputArgType : func.getCallableResults()) {
|
||||
outputABIArgTypes.push_back(getIntReprForABIType(outputArgType));
|
||||
outputABIElementTypes.push_back(
|
||||
getIntReprForABIElementType(outputArgType));
|
||||
outputABIShapes.push_back(
|
||||
getExtentsForType(outputArgType, /*maxRank=*/kMaxRank));
|
||||
outputABIRanks.push_back(getRankForType(outputArgType));
|
||||
// outputIsStatic.push_back(hasStaticShape(outputArgType));
|
||||
}
|
||||
|
||||
auto i32Type = builder.getIntegerType(32);
|
||||
auto inputABIDataType =
|
||||
RankedTensorType::get(inputABIArgTypes.size(), i32Type);
|
||||
auto inputABIElementType =
|
||||
RankedTensorType::get(inputABIElementTypes.size(), i32Type);
|
||||
auto inputABIShapesType = RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>{static_cast<long>(inputABIShapes.size()) *
|
||||
kMaxRank},
|
||||
i32Type);
|
||||
auto inputABIRanksType =
|
||||
RankedTensorType::get(inputABIRanks.size(), i32Type);
|
||||
// auto inputIsStaticType = RankedTensorType::get(inputIsStatic.size(),
|
||||
// i32Type);
|
||||
auto outputABIDataType =
|
||||
RankedTensorType::get(outputABIArgTypes.size(), i32Type);
|
||||
auto outputABIElementType =
|
||||
RankedTensorType::get(outputABIElementTypes.size(), i32Type);
|
||||
auto outputABIShapesType = RankedTensorType::get(
|
||||
llvm::ArrayRef<int64_t>{static_cast<long>(outputABIShapes.size()) *
|
||||
kMaxRank},
|
||||
i32Type);
|
||||
auto outputABIRanksType =
|
||||
RankedTensorType::get(outputABIRanks.size(), i32Type);
|
||||
// auto outputIsStaticType = RankedTensorType::get(outputIsStatic.size(),
|
||||
// i32Type);
|
||||
|
||||
// TODO(brycearden): I'm sure there's a cleaner way to do this
|
||||
auto flattenABIShapes =
|
||||
[](SmallVector<SmallVector<int32_t, kMaxRank>, 6> shapes) {
|
||||
SmallVector<int32_t, 32> ret;
|
||||
for (auto &shape : shapes)
|
||||
for (auto &dim : shape)
|
||||
ret.push_back(dim);
|
||||
return ret;
|
||||
};
|
||||
|
||||
SmallVector<NamedAttribute, 16> namedAttrs;
|
||||
|
||||
// Add attributes that are valid for every func (funcName, numInputs,
|
||||
// numOutputs)
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("funcName", module.getContext()),
|
||||
SymbolRefAttr::get(builder.getContext(), func.getName())));
|
||||
namedAttrs.push_back(
|
||||
std::make_pair(Identifier::get("numInputs", module.getContext()),
|
||||
builder.getI32IntegerAttr(func.getNumArguments())));
|
||||
namedAttrs.push_back(
|
||||
std::make_pair(Identifier::get("numOutputs", module.getContext()),
|
||||
builder.getI32IntegerAttr(func.getNumResults())));
|
||||
|
||||
if (inputABIArgTypes.size()) {
|
||||
// Only add input information if there are inputs
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputArgTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(inputABIDataType,
|
||||
llvm::makeArrayRef(inputABIArgTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputElementTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(inputABIElementType,
|
||||
llvm::makeArrayRef(inputABIElementTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputRanks", func.getContext()),
|
||||
DenseIntElementsAttr::get(inputABIRanksType,
|
||||
llvm::makeArrayRef(inputABIRanks))));
|
||||
auto inputShapesFlattened = flattenABIShapes(inputABIShapes);
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("inputShapes", func.getContext()),
|
||||
DenseIntElementsAttr::get(
|
||||
inputABIShapesType,
|
||||
llvm::makeArrayRef(flattenABIShapes(inputABIShapes)))));
|
||||
}
|
||||
|
||||
if (outputABIArgTypes.size()) {
|
||||
// Only add output information if there are outptus
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputArgTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(outputABIDataType,
|
||||
llvm::makeArrayRef(outputABIArgTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputElementTypes", func.getContext()),
|
||||
DenseIntElementsAttr::get(
|
||||
outputABIElementType,
|
||||
llvm::makeArrayRef(outputABIElementTypes))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputRanks", func.getContext()),
|
||||
DenseIntElementsAttr::get(outputABIRanksType,
|
||||
llvm::makeArrayRef(outputABIRanks))));
|
||||
namedAttrs.push_back(std::make_pair(
|
||||
Identifier::get("outputShapes", func.getContext()),
|
||||
DenseIntElementsAttr::get(
|
||||
outputABIShapesType,
|
||||
llvm::makeArrayRef(flattenABIShapes(outputABIShapes)))));
|
||||
}
|
||||
|
||||
builder.create<refbackrt::FuncMetadataOp>(func.getLoc(), ArrayRef<Type>{},
|
||||
ArrayRef<Value>{}, namedAttrs);
|
||||
|
||||
if (!expressibleWithRefbackrtABI(func.getType()))
|
||||
return func.emitError() << "func not expressible with refbackrt ABI";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Dialect conversion.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerAssertOp : public OpConversionPattern<AssertOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
AssertOp::Adaptor adaptor(operands);
|
||||
// The refbackrt runtime function aborts if the argument is true, rather
|
||||
// than when it is false as an `assert` does. So negate the predicate (by
|
||||
// xor'ing with 1).
|
||||
auto c1 = rewriter.create<ConstantOp>(
|
||||
op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(),
|
||||
APInt(/*numBits=*/1, /*val=*/1)));
|
||||
Value assertFailed = rewriter.create<XOrOp>(op.getLoc(), adaptor.arg(), c1);
|
||||
rewriter.replaceOpWithNewOp<refbackrt::AbortIfOp>(op, assertFailed,
|
||||
op.msgAttr());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// At ABI boundaries, convert all memrefs to unranked memrefs so that they have
|
||||
// a fixed ABI.
|
||||
class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(FuncOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
FunctionType type = op.getType();
|
||||
|
||||
TypeConverter::SignatureConversion entryConversion(type.getNumInputs());
|
||||
if (failed(typeConverter->convertSignatureArgs(type.getInputs(),
|
||||
entryConversion)))
|
||||
return rewriter.notifyMatchFailure(op, "could not convert inputs");
|
||||
SmallVector<Type, 1> newResultTypes;
|
||||
if (failed(typeConverter->convertTypes(type.getResults(), newResultTypes)))
|
||||
return rewriter.notifyMatchFailure(op, "could not convert outputs");
|
||||
|
||||
rewriter.updateRootInPlace(op, [&] {
|
||||
// Update the function type.
|
||||
op.setType(FunctionType::get(op.getContext(),
|
||||
entryConversion.getConvertedTypes(),
|
||||
newResultTypes));
|
||||
// Rewrite the entry block.
|
||||
Block &oldEntry = op.getBody().front();
|
||||
Block &newEntry =
|
||||
*rewriter.applySignatureConversion(&op.getBody(), entryConversion);
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPointToStart(&newEntry);
|
||||
BlockArgument newArg, oldArg;
|
||||
for (auto newAndOldArg :
|
||||
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
|
||||
std::tie(newArg, oldArg) = newAndOldArg;
|
||||
auto memref = rewriter.create<memref::CastOp>(op.getLoc(), newArg,
|
||||
oldArg.getType());
|
||||
rewriter.replaceUsesOfBlockArgument(oldArg, memref);
|
||||
}
|
||||
});
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
// At the return ABI boundaries, convert to the ABI type.
|
||||
// This pattern is needed to trigger the type conversion mechanics to do a
|
||||
// target materialization.
|
||||
class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op, operands);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult doDialectConversion(ModuleOp module) {
|
||||
auto *context = module.getContext();
|
||||
|
||||
TypeConverter typeConverter;
|
||||
typeConverter.addConversion([](Type type) { return type; });
|
||||
typeConverter.addConversion(
|
||||
[](MemRefType type) { return getABIMemrefType(type); });
|
||||
typeConverter.addTargetMaterialization(
|
||||
[](OpBuilder &builder, UnrankedMemRefType type, ValueRange inputs,
|
||||
Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
return builder.create<memref::CastOp>(
|
||||
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
|
||||
});
|
||||
|
||||
RewritePatternSet patterns(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addLegalDialect<refbackrt::RefbackrtDialect>();
|
||||
target.addLegalDialect<StandardOpsDialect>();
|
||||
target.addLegalDialect<memref::MemRefDialect>();
|
||||
|
||||
patterns.add<FuncOpSignatureConversion>(typeConverter, context);
|
||||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
|
||||
patterns.add<RewriteReturnOp>(typeConverter, context);
|
||||
target.addDynamicallyLegalOp<ReturnOp>(
|
||||
[&](ReturnOp op) { return typeConverter.isLegal(op); });
|
||||
|
||||
patterns.add<LowerAssertOp>(context);
|
||||
target.addIllegalOp<AssertOp>();
|
||||
|
||||
return applyPartialConversion(module, target, std::move(patterns));
|
||||
}
|
||||
|
||||
namespace {
|
||||
// This pass lowers the public ABI of the module to the primitives exposed by
|
||||
// the refbackrt dialect.
|
||||
class LowerToRefbackrtABI
|
||||
: public LowerToRefbackrtABIBase<LowerToRefbackrtABI> {
|
||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||
registry.insert<refbackrt::RefbackrtDialect, memref::MemRefDialect>();
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
ModuleOp module = getOperation();
|
||||
|
||||
// Before we lower anything, capture any needed metadata about the argument
|
||||
// lists that will be needed for safely invoking the raw runtime functions
|
||||
// later. (for example, number of expected arguments/results, types,
|
||||
// etc.)
|
||||
if (failed(createModuleMetadata(module)))
|
||||
return signalPassFailure();
|
||||
|
||||
// Now do the actual conversion / lowering.
|
||||
if (failed(doDialectConversion(module)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::createLowerToRefbackrtABIPass() {
|
||||
return std::make_unique<LowerToRefbackrtABI>();
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
//===- PassDetail.h - RefBackend Pass class details -------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef REFBACKEND_PASSDETAIL_H
|
||||
#define REFBACKEND_PASSDETAIL_H
|
||||
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "npcomp/RefBackend/Passes.h.inc"
|
||||
|
||||
} // namespace NPCOMP
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // REFBACKEND_PASSDETAIL_H
|
|
@ -1,234 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is the base file for npcomp's "reference backend".
|
||||
//
|
||||
// The input to this backend is a layer that consists of linalg-on-tensors
|
||||
// together with std scalar ops and control flow.
|
||||
//
|
||||
// The output of this backend is LLVM IR suitable for JITing.
|
||||
//
|
||||
// We expect that other backends will appear that have a similar kind of
|
||||
// interface. IREE already uses this layering.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/RefBackend/RefBackend.h"
|
||||
#include "PassDetail.h"
|
||||
|
||||
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
|
||||
#include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
|
||||
#include "mlir/Conversion/ShapeToStandard/ShapeToStandard.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
|
||||
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
|
||||
#include "mlir/Dialect/Linalg/Passes.h"
|
||||
#include "mlir/Dialect/MemRef/IR/MemRef.h"
|
||||
#include "mlir/Dialect/SCF/Passes.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/Dialect/Tensor/Transforms/Passes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassRegistry.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "npcomp/Backend/Common/Passes.h"
|
||||
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pass registration
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
#define GEN_PASS_REGISTRATION
|
||||
#include "npcomp/RefBackend/Passes.h.inc"
|
||||
} // end namespace
|
||||
|
||||
void mlir::NPCOMP::registerRefBackendPasses() {
|
||||
::registerPasses();
|
||||
|
||||
mlir::PassPipelineRegistration<RefBackendLoweringPipelineOptions>(
|
||||
"refback-lowering-pipeline", "RefBackend lowering pipeline.",
|
||||
mlir::NPCOMP::createRefBackendLoweringPipeline);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LowerAllocMemRefOps
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerAllocMemRefOp : public OpRewritePattern<refback::AllocMemRefOp> {
|
||||
public:
|
||||
using OpRewritePattern::OpRewritePattern;
|
||||
LogicalResult matchAndRewrite(refback::AllocMemRefOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto memrefType = op.getType().cast<MemRefType>();
|
||||
auto shape = op.getOperand();
|
||||
// std.alloc only accepts the dynamic extents as operands, so only
|
||||
// collect those.
|
||||
SmallVector<Value, 6> dynamicExtents;
|
||||
for (int i = 0, e = memrefType.getRank(); i < e; i++) {
|
||||
if (memrefType.isDynamicDim(i)) {
|
||||
auto ci = rewriter.create<ConstantIndexOp>(op.getLoc(), i);
|
||||
auto extent = rewriter.create<tensor::ExtractOp>(op.getLoc(), shape,
|
||||
ValueRange({ci}));
|
||||
dynamicExtents.push_back(extent);
|
||||
}
|
||||
}
|
||||
rewriter.replaceOpWithNewOp<memref::AllocOp>(op, memrefType,
|
||||
dynamicExtents);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerAllocMemRefOps
|
||||
: public LowerAllocMemRefOpsBase<LowerAllocMemRefOps> {
|
||||
|
||||
void runOnOperation() override {
|
||||
auto func = getOperation();
|
||||
auto *context = &getContext();
|
||||
RewritePatternSet patterns(context);
|
||||
patterns.add<LowerAllocMemRefOp>(context);
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<refback::AllocMemRefOp>();
|
||||
target.addLegalOp<tensor::ExtractOp>();
|
||||
target.addLegalOp<memref::AllocOp>();
|
||||
target.addLegalOp<ConstantOp>();
|
||||
if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
mlir::NPCOMP::createLowerAllocMemRefOpsPass() {
|
||||
return std::make_unique<LowerAllocMemRefOps>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// createRefBackendLoweringPipeline
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
||||
OpPassManager &pm, const RefBackendLoweringPipelineOptions &options) {
|
||||
|
||||
// Delete dead lists. RefBackend doesn't support lists, but in some cases
|
||||
// we can get by if all the lists are dead.
|
||||
pm.addNestedPass<FuncOp>(
|
||||
NPCOMP::CommonBackend::createDeleteDeadIREEListsPass());
|
||||
|
||||
// Convert all elementwise ops to linalg.
|
||||
//
|
||||
// Considering correctness, this lets us reuse the linalg bufferization, which
|
||||
// applies uniformly to all linalg structured ops.
|
||||
//
|
||||
// Also, converting to linalg herevopens up a lot of optimization
|
||||
// opportunities.
|
||||
pm.addNestedPass<FuncOp>(createConvertElementwiseToLinalgPass());
|
||||
|
||||
if (options.optimize) {
|
||||
pm.addNestedPass<FuncOp>(createLinalgElementwiseOpFusionPass());
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// Lower shape constraints before we enter tensor->memref conversion.
|
||||
// That is, we expand shape.cstr_* ops to eager error handling code.
|
||||
pm.addNestedPass<FuncOp>(createConvertShapeConstraintsPass());
|
||||
// Run shape canonicalizations. In particular, this erases shape.assuming,
|
||||
// now that we have converted shape constraints.
|
||||
// TODO: Don't canonicalize everything.
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
|
||||
// Lower shape ops to std.
|
||||
pm.addPass(createConvertShapeToStandardPass());
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Lower the `tensor` type to `memref`.
|
||||
// --------------------------------------------------------------------------
|
||||
// We make a conscious effort here to do this as a sequence of separate passes
|
||||
// rather than a single mega dialect conversion pass.
|
||||
//
|
||||
// This means that intermediate steps have source/target materializations
|
||||
// (memref.tensor_load / memref.buffer_cast) in the IR.
|
||||
|
||||
// Run tensor constant bufferization.
|
||||
// This pass has to run on a module op, and so does the final
|
||||
// FuncBufferizePass. But everything else can run in parallel on functions,
|
||||
// so we try to bracket the entire bufferization pipeline with the module
|
||||
// passes to allow maximum parallelism.
|
||||
pm.addPass(createTensorConstantBufferizePass());
|
||||
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
|
||||
// We need to resolve this to std.alloc which takes individual extents.
|
||||
pm.addNestedPass<FuncOp>(createLowerAllocMemRefOpsPass());
|
||||
pm.addNestedPass<FuncOp>(createSCFBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createLinalgBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createStdBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createTensorBufferizePass());
|
||||
pm.addPass(createFuncBufferizePass());
|
||||
pm.addNestedPass<FuncOp>(createFinalizingBufferizePass());
|
||||
|
||||
// TODO: Do buffer deallocation. We should be able to just drop in the
|
||||
// upstream pass?
|
||||
|
||||
// At this point, we have lots of loose stuff floating around from lowering,
|
||||
// so it's a good time to do some general cleanups.
|
||||
if (options.optimize) {
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Preparation for converting to an LLVM module.
|
||||
// --------------------------------------------------------------------------
|
||||
// Now, we begin the process of lowering to LLVM's level of abstraction
|
||||
// (after which LLVM will take over lowering to machine code).
|
||||
|
||||
// Lower linalg ops to loops.
|
||||
// TODO: Do some linalg optimizations like tiling here.
|
||||
pm.addNestedPass<FuncOp>(createConvertLinalgToLoopsPass());
|
||||
|
||||
// Run a some cleanups.
|
||||
if (options.optimize) {
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Final conversion to an LLVM module.
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
// Convert affine to std control flow in preparation for going to LLVM.
|
||||
pm.addNestedPass<FuncOp>(createLowerAffinePass());
|
||||
|
||||
// Convert scf to std control flow in preparation for going to LLVM.
|
||||
pm.addNestedPass<FuncOp>(createLowerToCFGPass());
|
||||
|
||||
// Convert functions signatures and other constructs that interface with the
|
||||
// runtime to the `refbackrt` dialect.
|
||||
pm.addPass(createLowerToRefbackrtABIPass());
|
||||
|
||||
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
|
||||
// which reuses the upstream patterns and gives us a place to add our own
|
||||
// patterns for our own custom ops like the refbackrt ops.
|
||||
pm.addPass(createLowerToLLVMPass());
|
||||
|
||||
// Although LLVM will clean everything up eventually, for the sake of IR
|
||||
// clarity while still in MLIR, run some cleanups.
|
||||
if (options.optimize) {
|
||||
pm.addNestedPass<FuncOp>(createCanonicalizerPass());
|
||||
pm.addNestedPass<FuncOp>(createCSEPass());
|
||||
}
|
||||
}
|
|
@ -1,36 +0,0 @@
|
|||
# Appease LLVM check that all sources are covered by a target.
|
||||
# This doesn't seem to play well with having multiple targets
|
||||
# in a single directory.
|
||||
set(LLVM_OPTIONAL_SOURCES
|
||||
Runtime.cpp
|
||||
CompilerRuntime.cpp
|
||||
)
|
||||
|
||||
# The library that users link against, defining basic interactions with an
|
||||
# refbackrt module and the relevant data structures.
|
||||
add_npcomp_library(NPCOMPRuntime
|
||||
Runtime.cpp
|
||||
)
|
||||
|
||||
mlir_check_all_link_libraries(NPCOMPRuntime)
|
||||
|
||||
# The library that defines the symbols that the compiler emits references
|
||||
# to.
|
||||
# Note: is uses some of the same facilities that the user API depends on,
|
||||
# we use a linker script to ensure that the shared library only exposes the
|
||||
# symbols the compiler needs.
|
||||
#
|
||||
# This is currently done as a shared library to make it suitable for being
|
||||
# loaded by mlir::ExecutionEngine. In e.g. an embedded scenario, we would
|
||||
# need to create a static library and link that into the binary.
|
||||
add_npcomp_library(NPCOMPCompilerRuntimeShlib
|
||||
SHARED
|
||||
CompilerRuntime.cpp
|
||||
|
||||
EXCLUDE_FROM_LIBNPCOMP
|
||||
)
|
||||
target_link_libraries(NPCOMPCompilerRuntimeShlib PRIVATE NPCOMPRuntime)
|
||||
if (UNIX AND NOT APPLE)
|
||||
set_target_properties(NPCOMPCompilerRuntimeShlib PROPERTIES LINK_FLAGS
|
||||
"-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/unix_version.script")
|
||||
endif()
|
|
@ -1,87 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains data structures (which we typically call "descriptors")
|
||||
// that are emitted by the compiler and must be kept in sync with the compiler
|
||||
// code that creates them in LowerToLLVM.cpp.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||
#define NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace refbackrt {
|
||||
|
||||
// All arguments are packed into this type-erased form for being invoked. See
|
||||
// LowerToLLVM.cpp for more details.
|
||||
typedef void ABIFunc(void **, void **);
|
||||
|
||||
enum class ABIArgType : std::uint32_t {
|
||||
kNone = 0,
|
||||
kMemref,
|
||||
kF32,
|
||||
kF64,
|
||||
};
|
||||
|
||||
enum class ABIElementType : std::uint32_t {
|
||||
kNone = 0,
|
||||
kF32,
|
||||
};
|
||||
|
||||
struct InputDescriptor {
|
||||
ABIArgType abiType;
|
||||
ABIElementType elementType;
|
||||
|
||||
std::int32_t rank;
|
||||
std::int32_t* extents;
|
||||
|
||||
// TODO(brycearden): Change to bool at ABI boundary
|
||||
// std::int32_t isStatic;
|
||||
};
|
||||
|
||||
struct OutputDescriptor {
|
||||
ABIArgType abiType;
|
||||
ABIElementType elementType;
|
||||
|
||||
std::int32_t rank;
|
||||
std::int32_t* extents;
|
||||
|
||||
// TODO(brycearden): Change to bool at ABI boundary
|
||||
//std::int32_t isStatic;
|
||||
};
|
||||
|
||||
struct FuncDescriptor {
|
||||
// The length of the function name.
|
||||
std::int32_t nameLen;
|
||||
// The name of the function, to allow lookup.
|
||||
const char *name;
|
||||
// This is a raw function pointer to the function's entry point as
|
||||
// emitted by the compiler.
|
||||
ABIFunc *functionPtr;
|
||||
// The number of inputs to the function.
|
||||
std::int32_t numInputs;
|
||||
// The number of outputs of the function.
|
||||
std::int32_t numOutputs;
|
||||
// TODO: Add shape checking to arg / result descriptor(s)
|
||||
InputDescriptor *inputDescriptors;
|
||||
OutputDescriptor *outputDescriptors;
|
||||
};
|
||||
|
||||
// The top-level entry point of the module metadata emitted by the
|
||||
// compiler. Unlike all the other descriptors here, external code does handle
|
||||
// this type (albeit through an opaque pointer).
|
||||
struct ModuleDescriptor {
|
||||
std::int32_t numFuncDescriptors;
|
||||
FuncDescriptor *functionDescriptors;
|
||||
};
|
||||
|
||||
} // namespace refbackrt
|
||||
|
||||
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
|
@ -1,33 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Symbols referenced only by the compiler and which will be compiled into a
|
||||
// shared object that a JIT can load to provide those symbols.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <array>
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
|
||||
#include "CompilerDataStructures.h"
|
||||
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
||||
|
||||
using namespace refbackrt;
|
||||
|
||||
extern "C" {
|
||||
__attribute__((visibility("default"))) void
|
||||
__npcomp_compiler_rt_abort_if(bool b, const char *msg);
|
||||
}
|
||||
|
||||
void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
|
||||
if (b) {
|
||||
std::fprintf(stderr, "NPCOMP: aborting: %s\n", msg);
|
||||
std::exit(1);
|
||||
}
|
||||
}
|
|
@ -1,519 +0,0 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
||||
|
||||
#include "llvm/Support/ErrorHandling.h"
|
||||
|
||||
#include <array>
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
#include "CompilerDataStructures.h"
|
||||
|
||||
using namespace refbackrt;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Memref descriptors for interacting with MLIR codegenerated code.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// These definitions are based on the ones in
|
||||
// `mlir/ExecutionEngine/CRunnerUtils.h` and the layouts need to be kept in
|
||||
// sync.
|
||||
//
|
||||
// Those definitions are flawed though because they are overly templated.
|
||||
struct MemrefDescriptor {
|
||||
void *allocatedPtr;
|
||||
void *dataPtr;
|
||||
std::int64_t offset;
|
||||
// Tail-allocated int64_t sizes followed by strides.
|
||||
MutableArrayRef<std::int64_t> getSizes(int assumedRank) {
|
||||
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
|
||||
return MutableArrayRef<std::int64_t>(tail, assumedRank);
|
||||
}
|
||||
MutableArrayRef<std::int64_t> getStrides(int assumedRank) {
|
||||
auto *tail = reinterpret_cast<std::int64_t *>(this + 1);
|
||||
return MutableArrayRef<std::int64_t>(tail + assumedRank, assumedRank);
|
||||
}
|
||||
|
||||
// Returns a malloc-allocated MemrefDescriptor with the specified extents and
|
||||
// default striding.
|
||||
static MemrefDescriptor *create(ArrayRef<std::int32_t> extents, void *data);
|
||||
|
||||
// Returns the number of elements in this MemrefDescriptor, assuming this
|
||||
// descriptor has rank `assumedRank`.
|
||||
std::int32_t getNumElements(int assumedRank) {
|
||||
if (assumedRank == 0)
|
||||
return 1;
|
||||
return getSizes(assumedRank)[0] * getStrides(assumedRank)[0];
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
struct UnrankedMemref {
|
||||
int64_t rank;
|
||||
MemrefDescriptor *descriptor;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
MemrefDescriptor *MemrefDescriptor::create(ArrayRef<std::int32_t> extents,
|
||||
void *data) {
|
||||
auto rank = extents.size();
|
||||
auto allocSize = sizeof(MemrefDescriptor) + sizeof(std::int64_t) * 2 * rank;
|
||||
auto *descriptor = static_cast<MemrefDescriptor *>(std::malloc(allocSize));
|
||||
descriptor->allocatedPtr = data;
|
||||
descriptor->dataPtr = data;
|
||||
descriptor->offset = 0;
|
||||
// Iterate in reverse, copying the dimension sizes (i.e. extents) and
|
||||
// calculating the strides for a standard dense layout.
|
||||
std::int64_t stride = 1;
|
||||
for (int i = 0, e = rank; i < e; i++) {
|
||||
auto revIdx = e - i - 1;
|
||||
descriptor->getSizes(rank)[revIdx] = extents[revIdx];
|
||||
descriptor->getStrides(rank)[revIdx] = stride;
|
||||
stride *= extents[revIdx];
|
||||
}
|
||||
return descriptor;
|
||||
}
|
||||
|
||||
static UnrankedMemref convertRefbackrtTensorToUnrankedMemref(Tensor *tensor) {
|
||||
auto byteSize = tensor->getDataByteSize();
|
||||
void *data = std::malloc(byteSize);
|
||||
std::memcpy(data, tensor->getData(), byteSize);
|
||||
auto *descriptor = MemrefDescriptor::create(tensor->getExtents(), data);
|
||||
return UnrankedMemref{tensor->getRank(), descriptor};
|
||||
}
|
||||
|
||||
static Tensor *convertUnrankedMemrefToRefbackrtTensor(
|
||||
std::int64_t rank, MemrefDescriptor *descriptor, ElementType elementType) {
|
||||
// Launder from std::int64_t to std::int32_t.
|
||||
auto extents64 = descriptor->getSizes(rank);
|
||||
constexpr int kMaxRank = 20;
|
||||
std::array<std::int32_t, kMaxRank> extents32Buf;
|
||||
for (int i = 0, e = extents64.size(); i < e; i++)
|
||||
extents32Buf[i] = extents64[i];
|
||||
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
|
||||
elementType, descriptor->dataPtr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Tensor
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
|
||||
std::int32_t ret = 1;
|
||||
for (int i = 0, e = extents.size(); i < e; i++) {
|
||||
ret *= extents[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
|
||||
switch (type) {
|
||||
case ElementType::NONE:
|
||||
return 0;
|
||||
case ElementType::F32:
|
||||
return 4;
|
||||
}
|
||||
llvm_unreachable("unsupported dtype");
|
||||
}
|
||||
|
||||
StringRef refbackrt::getElementTypeAsStringRef(ElementType type) {
|
||||
switch (type) {
|
||||
case ElementType::NONE:
|
||||
return "NONE";
|
||||
case ElementType::F32:
|
||||
return "F32";
|
||||
}
|
||||
llvm_unreachable("unsupported element type string");
|
||||
}
|
||||
|
||||
StringRef refbackrt::getArgTypeAsStringRef(ArgType type) {
|
||||
switch (type) {
|
||||
case ArgType::kNone:
|
||||
return "kNone";
|
||||
case ArgType::kTensor:
|
||||
return "kTensor";
|
||||
case ArgType::kF32:
|
||||
return "kF32";
|
||||
case ArgType::kF64:
|
||||
return "kF64";
|
||||
}
|
||||
llvm_unreachable("unsupported arg type string");
|
||||
}
|
||||
|
||||
Ref<Tensor> Tensor::create(ArrayRef<std::int32_t> extents, ElementType type,
|
||||
void *data) {
|
||||
return Ref<Tensor>(createRaw(extents, type, data));
|
||||
}
|
||||
|
||||
Tensor *Tensor::createRaw(ArrayRef<std::int32_t> extents, ElementType type,
|
||||
void *data) {
|
||||
auto *tensor = static_cast<Tensor *>(
|
||||
std::malloc(sizeof(Tensor) + extents.size() * sizeof(std::int32_t)));
|
||||
|
||||
tensor->refCount = 0;
|
||||
tensor->elementType = type;
|
||||
tensor->rank = extents.size();
|
||||
auto byteSize = getElementTypeByteSize(type) * totalElements(extents);
|
||||
// TODO: Align the buffer.
|
||||
tensor->allocatedPtr = std::malloc(byteSize);
|
||||
tensor->data = tensor->allocatedPtr;
|
||||
std::memcpy(tensor->data, data, byteSize);
|
||||
for (int i = 0, e = extents.size(); i < e; i++)
|
||||
tensor->getMutableExtents()[i] = extents[i];
|
||||
return tensor;
|
||||
}
|
||||
|
||||
std::int32_t Tensor::getDataByteSize() const {
|
||||
return getElementTypeByteSize(getElementType()) * totalElements(getExtents());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module metadata descriptors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module operations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename T> static void *ToVoidPtr(T *ptr) {
|
||||
return const_cast<void *>(static_cast<const void *>(ptr));
|
||||
}
|
||||
static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
|
||||
StringRef name) {
|
||||
for (int i = 0, e = moduleDescriptor->numFuncDescriptors; i < e; i++) {
|
||||
auto &functionDescriptor = moduleDescriptor->functionDescriptors[i];
|
||||
if (StringRef(functionDescriptor.name, functionDescriptor.nameLen) ==
|
||||
name) {
|
||||
return &functionDescriptor;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||
StringRef functionName, ArrayRef<RtValue> inputs,
|
||||
MutableArrayRef<RtValue> outputs) {
|
||||
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
||||
assert(descriptor && "unknown function name");
|
||||
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
||||
assert(outputs.size() < kMaxArity && "number of outputs exceeds kMaxArity");
|
||||
|
||||
// We haven't committed to using "vector" in this runtime code, so use
|
||||
// a fixed-sized array.
|
||||
std::array<UnrankedMemref, kMaxArity> inputUnrankedMemrefs;
|
||||
std::array<UnrankedMemref, kMaxArity> outputUnrankedMemrefs;
|
||||
std::array<void *, kMaxArity * 2> packedInputs;
|
||||
std::array<void *, kMaxArity> packedOutputs;
|
||||
|
||||
// Deepcopy the refbackrt::Tensor's into UnrankedMemref's.
|
||||
// TODO: Avoid the deep copy. It makes the later lifetime management code
|
||||
// more complex though (and maybe impossible given the current abstractions).
|
||||
//
|
||||
// Create a type-erased list of "packed inputs" to pass to the
|
||||
// LLVM/C ABI wrapper function. Each packedInput pointer corresponds to
|
||||
// one LLVM/C ABI argument to the underlying function.
|
||||
//
|
||||
// The ABI lowering on StandardToLLVM conversion side will
|
||||
// "explode" the unranked memref descriptors on the underlying function
|
||||
// into separate arguments for the rank and pointer-to-descriptor.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
auto idx = 2 * i;
|
||||
if (inputs[i].isTensor()) {
|
||||
inputUnrankedMemrefs[i] =
|
||||
convertRefbackrtTensorToUnrankedMemref(inputs[i].toTensor().get());
|
||||
packedInputs[idx] = ToVoidPtr(&inputUnrankedMemrefs[i].rank);
|
||||
packedInputs[idx + 1] = ToVoidPtr(&inputUnrankedMemrefs[i].descriptor);
|
||||
} else if (inputs[i].isScalar()) {
|
||||
packedInputs[idx] = ToVoidPtr(&inputs[i]);
|
||||
} else {
|
||||
assert(false && "unsupported input RtValue type");
|
||||
}
|
||||
}
|
||||
|
||||
// Create a type-erased list of "packed output" to pass to the
|
||||
// LLVM/C ABI wrapper function.
|
||||
//
|
||||
// Due to how StandardToLLVM lowering works, each packedOutput pointer
|
||||
// corresponds to a single UnrankedMemref (not "exploded").
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
if (outputs[i].isTensor()) {
|
||||
packedOutputs[i] = ToVoidPtr(&outputUnrankedMemrefs[i]);
|
||||
} else if (outputs[i].isScalar()) {
|
||||
packedOutputs[i] = ToVoidPtr(&outputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Actually invoke the function!
|
||||
descriptor->functionPtr(packedInputs.data(), packedOutputs.data());
|
||||
|
||||
// Copy out the result data into refbackrt::Tensor's.
|
||||
// TODO: Avoid needing to make a deep copy.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
// TODO: Have compiler emit the element type in the metadata.
|
||||
if (outputs[i].isTensor()) {
|
||||
auto elementType = ElementType::F32;
|
||||
Tensor *tensor = convertUnrankedMemrefToRefbackrtTensor(
|
||||
outputUnrankedMemrefs[i].rank, outputUnrankedMemrefs[i].descriptor,
|
||||
elementType);
|
||||
outputs[i] = RtValue(Ref<Tensor>(tensor));
|
||||
} else if (outputs[i].isFloat()) {
|
||||
outputs[i] = RtValue(*(reinterpret_cast<float *>(packedOutputs[i])));
|
||||
}
|
||||
}
|
||||
|
||||
// Now, we just need to free all the UnrankedMemref's that we created.
|
||||
// This is complicated by the fact that multiple input/output UnrankedMemref's
|
||||
// can end up with the same backing buffer (`allocatedPtr`), and we need
|
||||
// to avoid double-freeing.
|
||||
// Output buffers might alias any other input or output buffer.
|
||||
// Input buffers are guaranteed to not alias each other.
|
||||
|
||||
// Free the output buffers.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
if (outputs[i].isRef()) {
|
||||
void *allocatedPtr = outputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
// Multiple returned memrefs can point into the same underlying
|
||||
// malloc allocation. Do a linear scan to see if any of the previously
|
||||
// deallocated buffers already freed this pointer.
|
||||
bool bufferNeedsFreeing = true;
|
||||
for (int j = 0; j < i; j++) {
|
||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||
bufferNeedsFreeing = false;
|
||||
}
|
||||
if (!bufferNeedsFreeing)
|
||||
std::free(allocatedPtr);
|
||||
}
|
||||
}
|
||||
|
||||
// Free the input buffers.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
if (!inputs[i].isRef())
|
||||
continue;
|
||||
void *allocatedPtr = inputUnrankedMemrefs[i].descriptor->allocatedPtr;
|
||||
bool bufferNeedsFreeing = true;
|
||||
for (int j = 0, je = outputs.size(); j < je; j++) {
|
||||
if (!outputs[j].isRef())
|
||||
continue;
|
||||
if (allocatedPtr == outputUnrankedMemrefs[j].descriptor->allocatedPtr)
|
||||
bufferNeedsFreeing = false;
|
||||
}
|
||||
// HACK: The returned memref can point into statically allocated memory that
|
||||
// we can't pass to `free`, such as the result of lowering a tensor-valued
|
||||
// `std.constant` to `std.global_memref`. The LLVM lowering of
|
||||
// std.global_memref sets the allocated pointer to the magic value
|
||||
// 0xDEADBEEF, which we sniff for here. This is yet another strong signal
|
||||
// that memref is really not the right abstraction for ABI's.
|
||||
if (reinterpret_cast<std::intptr_t>(allocatedPtr) == 0xDEADBEEF)
|
||||
bufferNeedsFreeing = false;
|
||||
if (!bufferNeedsFreeing)
|
||||
std::free(allocatedPtr);
|
||||
}
|
||||
|
||||
// Free the output descriptors.
|
||||
for (int i = 0, e = outputs.size(); i < e; i++) {
|
||||
if (!outputs[i].isRef())
|
||||
continue;
|
||||
// The LLVM lowering guarantees that each returned unranked memref
|
||||
// descriptor is separately malloc'ed, so no need to do anything special
|
||||
// like we had to do for the allocatedPtr's.
|
||||
std::free(outputUnrankedMemrefs[i].descriptor);
|
||||
}
|
||||
// Free the input descriptors.
|
||||
for (int i = 0, e = inputs.size(); i < e; i++) {
|
||||
if (!inputs[i].isRef())
|
||||
continue;
|
||||
std::free(inputUnrankedMemrefs[i].descriptor);
|
||||
}
|
||||
}
|
||||
|
||||
static InputArgInfo
|
||||
getExternalInputArgInfo(const refbackrt::InputDescriptor &inputDescriptor) {
|
||||
InputArgInfo ret;
|
||||
|
||||
// Set arg / element types accordingly
|
||||
switch (inputDescriptor.abiType) {
|
||||
case ABIArgType::kNone:
|
||||
ret.argType = ArgType::kNone;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kMemref:
|
||||
ret.argType = ArgType::kTensor;
|
||||
ret.elementType = ElementType::F32;
|
||||
break;
|
||||
case ABIArgType::kF32:
|
||||
ret.argType = ArgType::kF32;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kF64:
|
||||
ret.argType = ArgType::kF64;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
}
|
||||
|
||||
// Extract shape information
|
||||
ret.rank = inputDescriptor.rank;
|
||||
for (int i = 0; i < inputDescriptor.rank; i++) {
|
||||
ret.extents[i] = inputDescriptor.extents[i];
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static OutputArgInfo
|
||||
getExternalOutputArgInfo(const refbackrt::OutputDescriptor &outputDescriptor) {
|
||||
OutputArgInfo ret;
|
||||
|
||||
// Set arg / element types accordingly
|
||||
switch (outputDescriptor.abiType) {
|
||||
case ABIArgType::kNone:
|
||||
ret.argType = ArgType::kNone;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kMemref:
|
||||
ret.argType = ArgType::kTensor;
|
||||
ret.elementType = ElementType::F32;
|
||||
break;
|
||||
case ABIArgType::kF32:
|
||||
ret.argType = ArgType::kF32;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
case ABIArgType::kF64:
|
||||
ret.argType = ArgType::kF64;
|
||||
ret.elementType = ElementType::NONE;
|
||||
break;
|
||||
}
|
||||
|
||||
// Extract shape information
|
||||
ret.rank = outputDescriptor.rank;
|
||||
for (int i = 0; i < outputDescriptor.rank; i++) {
|
||||
ret.extents[i] = outputDescriptor.extents[i];
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||
StringRef functionName,
|
||||
FunctionMetadata &outMetadata) {
|
||||
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
||||
if (!descriptor)
|
||||
return failure();
|
||||
outMetadata.numInputs = descriptor->numInputs;
|
||||
outMetadata.numOutputs = descriptor->numOutputs;
|
||||
|
||||
for (int i = 0; i < descriptor->numInputs; i++) {
|
||||
outMetadata.inputArgInfos[i] =
|
||||
getExternalInputArgInfo(descriptor->inputDescriptors[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i < descriptor->numOutputs; i++) {
|
||||
outMetadata.outputArgInfos[i] =
|
||||
getExternalOutputArgInfo(descriptor->outputDescriptors[i]);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult refbackrt::checkRtValueShapes(const RtValue &value,
|
||||
const InputArgInfo &info) {
|
||||
if (value.isTensor()) {
|
||||
auto refTensor = value.toTensor();
|
||||
|
||||
// Don't bother checking shapes for unranked tensors
|
||||
if (info.rank < 0)
|
||||
return success();
|
||||
|
||||
if (refTensor->getRank() != info.rank)
|
||||
return failure();
|
||||
|
||||
auto tensorExtents = refTensor->getExtents();
|
||||
for (int i = 0; i < info.rank; i++) {
|
||||
// If a dimension is dynamic, it is encoded as extent = -1
|
||||
// and we should skip checking over that dimension
|
||||
if (info.extents[i] > 0 && (info.extents[i] != tensorExtents[i]))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult refbackrt::checkRtValueArgTypes(const RtValue &value,
|
||||
const InputArgInfo &info) {
|
||||
// Generic checks based on argType(s)
|
||||
if ((value.isTensor() && info.argType != ArgType::kTensor) ||
|
||||
(value.isFloat() && info.argType != ArgType::kF32))
|
||||
return failure();
|
||||
|
||||
if (value.isRef()) {
|
||||
// Will need special error checking for ref-counted types
|
||||
// Currently only f32 tensors are supported
|
||||
if (value.isTensor()) {
|
||||
auto refTensor = value.toTensor();
|
||||
if (refTensor->getElementType() != ElementType::F32)
|
||||
return failure();
|
||||
} else {
|
||||
assert(false && "Unsupported input type checking for Ref type");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
RtValue refbackrt::createRtValueFromOutputArgInfo(const OutputArgInfo &info) {
|
||||
constexpr int32_t kDynamicConstantShape = 100;
|
||||
switch (info.argType) {
|
||||
case ArgType::kTensor: {
|
||||
// HACK: for dynamic dims the shape will be negative, so for now we are
|
||||
// just going to create a tensor of size kDynamicConstantShape
|
||||
std::array<int32_t, kMaxRank> tensorShape;
|
||||
for (int i = 0; i < info.rank; i++) {
|
||||
tensorShape[i] =
|
||||
info.extents[i] > 0 ? info.extents[i] : kDynamicConstantShape;
|
||||
}
|
||||
refbackrt::ArrayRef<int32_t> shape(tensorShape.data(), info.rank);
|
||||
int numel = 1;
|
||||
for (int i = 0; i < info.rank; i++)
|
||||
numel *= shape[i];
|
||||
|
||||
void *data;
|
||||
switch (info.elementType) {
|
||||
case ElementType::F32: {
|
||||
auto byteSize = numel * sizeof(float);
|
||||
data = static_cast<void *>(malloc(byteSize));
|
||||
assert(data && "could not allocate tensor");
|
||||
memset(data, 0, byteSize);
|
||||
return RtValue(Tensor::create(shape, ElementType::F32, data));
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
assert(false && "unknown output tensor type");
|
||||
return RtValue();
|
||||
}
|
||||
}
|
||||
|
||||
// The Tensor::create function will malloc and memcpy the data
|
||||
// into the Tensor object, so after we need to free our
|
||||
// temporary data buffer
|
||||
assert(data && "data ptr must exist");
|
||||
auto refTensor = Tensor::create(shape, ElementType::F32, data);
|
||||
free(data);
|
||||
return RtValue(refTensor);
|
||||
}
|
||||
case ArgType::kF32: {
|
||||
return RtValue(-20.0f);
|
||||
}
|
||||
default: {
|
||||
assert(false && "Don't know how to handle this artType");
|
||||
return RtValue();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,5 +0,0 @@
|
|||
{
|
||||
global: __npcomp_compiler_rt_*;
|
||||
local: *;
|
||||
};
|
||||
|
|
@ -3,25 +3,6 @@ include(MLIRDetectPythonEnv)
|
|||
|
||||
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=npcomp.")
|
||||
|
||||
################################################################################
|
||||
# Resources that must be packaged into the python tree
|
||||
################################################################################
|
||||
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/npcomp/compiler/backend/refjit_resources")
|
||||
add_custom_target(NPCOMPPythonResources ALL)
|
||||
add_custom_command(
|
||||
TARGET NPCOMPPythonResources
|
||||
COMMAND ${CMAKE_COMMAND} -E copy
|
||||
# TODO: Make the runtime library work for windows.
|
||||
# TODO: Use $<TARGET-FILE:> for this.
|
||||
${CMAKE_BINARY_DIR}/lib/libNPCOMPCompilerRuntimeShlib${CMAKE_SHARED_LIBRARY_SUFFIX}
|
||||
${MLIR_NPCOMP_PYTHON_PACKAGES_DIR}/npcomp_core/npcomp/compiler/generic/backend/libNPCOMPCompilerRuntimeShlib${CMAKE_SHARED_LIBRARY_SUFFIX}
|
||||
)
|
||||
add_dependencies(NPCOMPPythonResources
|
||||
NPCOMPCompilerRuntimeShlib
|
||||
)
|
||||
|
||||
|
||||
################################################################################
|
||||
# Declare sources
|
||||
################################################################################
|
||||
|
@ -50,17 +31,11 @@ declare_mlir_python_sources(NPCOMPPythonCAPIHeaderSources
|
|||
# Extensions
|
||||
################################################################################
|
||||
|
||||
set(_addl_extension_sources)
|
||||
if(NPCOMP_ENABLE_REFJIT)
|
||||
list(APPEND _addl_extension_sources "${CMAKE_CURRENT_SOURCE_DIR}/RefJITBackend.cpp")
|
||||
endif()
|
||||
|
||||
declare_mlir_python_extension(NPCOMPPythonExtensions.Core
|
||||
MODULE_NAME _npcomp
|
||||
ADD_TO_PARENT NPCOMPPythonExtensions
|
||||
SOURCES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/NpcompModule.cpp
|
||||
${_addl_extension_sources}
|
||||
EMBED_CAPI_LINK_LIBS
|
||||
NPCOMPCAPI
|
||||
PRIVATE_LINK_LIBS
|
||||
|
@ -109,9 +84,6 @@ add_mlir_python_modules(NPCOMPPythonModules
|
|||
COMMON_CAPI_LINK_LIBS
|
||||
NPCOMPPythonCAPI
|
||||
)
|
||||
add_dependencies(NPCOMPPythonModules
|
||||
NPCOMPPythonResources
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Torch support libraries.
|
||||
|
|
|
@ -38,10 +38,4 @@ PYBIND11_MODULE(_npcomp, m) {
|
|||
// Optional backend modules.
|
||||
auto backend_m = m.def_submodule("backend", "Backend support");
|
||||
(void)backend_m;
|
||||
|
||||
#ifdef NPCOMP_ENABLE_REFJIT
|
||||
auto refjit_m =
|
||||
backend_m.def_submodule("refjit", "Reference CPU Jit Backend");
|
||||
::npcomp::python::defineBackendRefJitModule(refjit_m);
|
||||
#endif
|
||||
}
|
||||
|
|
|
@ -1,164 +0,0 @@
|
|||
//===- PythonModule.cpp - RefJIT python bindings --------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "./NpcompModule.h"
|
||||
|
||||
#include <cstdlib>
|
||||
|
||||
#include "pybind11/numpy.h"
|
||||
|
||||
#include "npcomp-c/RefJITBackend.h"
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
||||
using llvm::SmallVector;
|
||||
using llvm::StringRef;
|
||||
using llvm::Twine;
|
||||
|
||||
static NpcompRefJitElementType
|
||||
mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
|
||||
if (format == "f")
|
||||
return NPCOMP_REFJIT_F32;
|
||||
|
||||
std::string message("unsupported buffer format: ");
|
||||
message.append(format);
|
||||
throw py::raiseValueError(message);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct PyRefJitModule {
|
||||
PyRefJitModule(NpcompRefJitModule instance) : instance(instance) {}
|
||||
~PyRefJitModule() {
|
||||
if (instance.ptr)
|
||||
npcompRefJitModuleDestroy(instance);
|
||||
}
|
||||
PyRefJitModule(const PyRefJitModule &) = delete;
|
||||
void operator=(const PyRefJitModule &) = delete;
|
||||
PyRefJitModule(PyRefJitModule &&other) : instance(other.instance) {
|
||||
other.instance.ptr = nullptr;
|
||||
}
|
||||
|
||||
NpcompRefJitModule instance = {nullptr};
|
||||
};
|
||||
|
||||
struct PyRefJitValueList {
|
||||
PyRefJitValueList(NpcompRefJitValueList instance) : instance(instance) {}
|
||||
~PyRefJitValueList() {
|
||||
if (instance.ptr)
|
||||
npcompRefJitValueListDestroy(instance);
|
||||
}
|
||||
PyRefJitValueList(const PyRefJitValueList &) = delete;
|
||||
void operator=(const PyRefJitValueList &) = delete;
|
||||
PyRefJitValueList(PyRefJitValueList &&other) : instance(other.instance) {
|
||||
other.instance.ptr = nullptr;
|
||||
}
|
||||
|
||||
NpcompRefJitValueList instance = {nullptr};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
void npcomp::python::defineBackendRefJitModule(py::module &m) {
|
||||
m.def("build_backend_compilation_pipeline", [](MlirPassManager capiPm) {
|
||||
npcompRefJitBuildBackendCompilationPipeline(capiPm, /*optimize=*/true);
|
||||
});
|
||||
py::class_<PyRefJitValueList>(m, "ValueList");
|
||||
py::class_<PyRefJitModule>(m, "JITModule")
|
||||
.def_static(
|
||||
"from_compiled_module",
|
||||
[](MlirModule capiModule,
|
||||
std::vector<std::string> pySharedLibs) -> PyRefJitModule {
|
||||
SmallVector<MlirStringRef, 4> sharedLibs;
|
||||
for (auto &s : pySharedLibs)
|
||||
sharedLibs.push_back(MlirStringRef{s.data(), s.size()});
|
||||
char *errorMessageCstr;
|
||||
NpcompRefJitModule m =
|
||||
npcompRefJitModuleCreate(capiModule, &sharedLibs[0],
|
||||
sharedLibs.size(), &errorMessageCstr);
|
||||
if (npcompRefJitModuleIsNull(m)) {
|
||||
std::string errorMessage(errorMessageCstr);
|
||||
std::free(errorMessageCstr);
|
||||
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
|
||||
}
|
||||
return PyRefJitModule(m);
|
||||
},
|
||||
py::arg("module"), py::arg("shared_libs"))
|
||||
.def(
|
||||
"invoke",
|
||||
[](PyRefJitModule &self, std::string functionName,
|
||||
std::vector<py::buffer> inputs) {
|
||||
py::object ioListObject =
|
||||
py::cast(PyRefJitValueList(npcompRefJitValueListCreate()));
|
||||
PyRefJitValueList &ioList =
|
||||
py::cast<PyRefJitValueList &>(ioListObject);
|
||||
|
||||
// Prepare inputs.
|
||||
for (auto &buffer : inputs) {
|
||||
// Request a C contiguous view as that is what Tensor accepts now
|
||||
// (no strides or non row-major layout).
|
||||
int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
|
||||
std::unique_ptr<Py_buffer> view(new Py_buffer());
|
||||
if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) {
|
||||
throw py::error_already_set();
|
||||
}
|
||||
py::buffer_info info(view.release());
|
||||
auto elementType =
|
||||
mapBufferFormatToElementType(info.format, info.itemsize);
|
||||
SmallVector<int32_t, 4> extents(info.shape.begin(),
|
||||
info.shape.end());
|
||||
|
||||
npcompRefJitValueAddTensorCopy(ioList.instance, elementType,
|
||||
extents.data(), extents.size(),
|
||||
info.ptr);
|
||||
}
|
||||
|
||||
// Invoke.
|
||||
char *errorMessageCstr;
|
||||
if (!npcompRefJitModuleInvoke(
|
||||
self.instance,
|
||||
MlirStringRef{functionName.data(), functionName.size()},
|
||||
ioList.instance, &errorMessageCstr)) {
|
||||
std::string errorMessage(errorMessageCstr);
|
||||
std::free(errorMessageCstr);
|
||||
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
|
||||
}
|
||||
|
||||
// Prepare outputs.
|
||||
std::vector<py::object> outputs;
|
||||
for (intptr_t i = 0; i < npcompRefJitValueListSize(ioList.instance);
|
||||
++i) {
|
||||
if (npcompRefJitValueIsaTensor(ioList.instance, i)) {
|
||||
NpcompRefJitElementType elementType;
|
||||
intptr_t rank;
|
||||
const int32_t *extents;
|
||||
void *data = npcompRefJitValueGetTensor(
|
||||
ioList.instance, i, &elementType, &rank, &extents);
|
||||
|
||||
const char *format;
|
||||
switch (elementType) {
|
||||
case NPCOMP_REFJIT_F32:
|
||||
format = "f";
|
||||
break;
|
||||
default:
|
||||
throw py::raiseValueError("unsupported tensor element type");
|
||||
}
|
||||
|
||||
outputs.push_back(
|
||||
py::array(py::dtype(format),
|
||||
llvm::ArrayRef<std::int32_t>(extents, rank), data,
|
||||
/*base=*/ioListObject));
|
||||
} else {
|
||||
throw py::raisePyError(PyExc_ValueError,
|
||||
"unsupported npcomp refjit return type");
|
||||
}
|
||||
}
|
||||
return outputs;
|
||||
},
|
||||
py::arg("function_name"), py::arg("inputs"));
|
||||
}
|
|
@ -1,50 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
"""Exports configuration for the package and settings for building libraries."""
|
||||
|
||||
import os
|
||||
import platform
|
||||
|
||||
from ._mlir_libs import get_include_dirs, get_lib_dirs
|
||||
|
||||
__all__ = [
|
||||
"get_include_dirs",
|
||||
"get_lib_dirs",
|
||||
]
|
||||
|
||||
|
||||
def get_capi_link_library_name() -> str:
|
||||
"""Gets the library name of the CAPI shared library to link against."""
|
||||
return "NPCOMPPythonCAPI"
|
||||
|
||||
|
||||
def get_capi_link_library_path() -> str:
|
||||
"""Returns an absolute path to the CAPI shared library.
|
||||
|
||||
This should be preferred when seeking to create a non relocatable package
|
||||
as it eliminates the possibility of interference of similar named libraries
|
||||
on the link path.
|
||||
|
||||
Raises:
|
||||
ValueError: If the library cannot be found.
|
||||
"""
|
||||
system = platform.system()
|
||||
lib_prefix = "lib"
|
||||
lib_suffix = ".so"
|
||||
if system == "Darwin":
|
||||
lib_suffix = ".dylib"
|
||||
elif system == "Windows":
|
||||
lib_prefix = ""
|
||||
lib_suffix = ".lib"
|
||||
lib_filename = f"{lib_prefix}{get_capi_link_library_name()}{lib_suffix}"
|
||||
|
||||
for lib_dir in get_lib_dirs():
|
||||
full_path = os.path.join(lib_dir, lib_filename)
|
||||
if os.path.exists(full_path): return full_path
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to find npcomp development library {lib_filename} in "
|
||||
f"{get_lib_dirs()}")
|
||||
|
|
@ -1,66 +0,0 @@
|
|||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
import os
|
||||
import platform
|
||||
|
||||
_refjit = None
|
||||
|
||||
|
||||
def get_refjit():
|
||||
"""Dynamically resolves the refjit backend native module."""
|
||||
global _refjit
|
||||
if _refjit is not None:
|
||||
return _refjit
|
||||
from ...._mlir_libs import _npcomp as _cext
|
||||
try:
|
||||
imported_refjit = _cext.backend.refjit
|
||||
except AttributeError:
|
||||
raise ImportError(
|
||||
"The npcomp native module was not compiled with refjit support")
|
||||
_refjit = imported_refjit
|
||||
return _refjit
|
||||
|
||||
|
||||
def is_enabled() -> bool:
|
||||
"""Returns whether the backend is enabled for the current build."""
|
||||
try:
|
||||
_get_refjit()
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def get_runtime_libs():
|
||||
# The _refjit_resources directory is at the npcomp.compiler level.
|
||||
resources_dir = os.path.join(os.path.dirname(__file__))
|
||||
suffix = ".so"
|
||||
if platform.system() == "Darwin":
|
||||
suffix = ".dylib"
|
||||
shlib_name = f"libNPCOMPCompilerRuntimeShlib{suffix}"
|
||||
return [os.path.join(resources_dir, shlib_name)]
|
||||
|
||||
|
||||
class JitModuleInvoker:
|
||||
"""Wrapper around a native JitModule for calling functions."""
|
||||
|
||||
def __init__(self, jit_module):
|
||||
super().__init__()
|
||||
self._jit_module = jit_module
|
||||
|
||||
def __getattr__(self, function_name):
|
||||
return self.__getitem__(function_name)
|
||||
|
||||
def __getitem__(self, function_name):
|
||||
|
||||
def invoke(*args):
|
||||
results = self._jit_module.invoke(function_name, args)
|
||||
if len(results) == 1:
|
||||
# De-tuple.
|
||||
return results[0]
|
||||
else:
|
||||
return tuple(results)
|
||||
|
||||
invoke.__isnpcomp__ = True
|
||||
return invoke
|
|
@ -8,7 +8,6 @@ configure_lit_site_cfg(
|
|||
set(NPCOMP_TEST_DEPENDS
|
||||
FileCheck count not
|
||||
npcomp-opt
|
||||
refback-run
|
||||
NPCOMPPythonModules
|
||||
)
|
||||
|
||||
|
|
|
@ -1,8 +0,0 @@
|
|||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @alloc_memref
|
||||
func @alloc_memref(%arg0: tensor<?xindex>) {
|
||||
// CHECK: refback.alloc_memref
|
||||
%0 = refback.alloc_memref %arg0 : memref<?xf32>
|
||||
return
|
||||
}
|
|
@ -1,23 +0,0 @@
|
|||
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
|
||||
|
||||
refbackrt.module_metadata {
|
||||
// expected-error @+1 {{must reference a valid func}}
|
||||
refbackrt.func_metadata {funcName = @g, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
refbackrt.module_metadata {
|
||||
// expected-error @+1 {{must agree on number of inputs}}
|
||||
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||
}
|
||||
func @f() { return }
|
||||
|
||||
// -----
|
||||
|
||||
refbackrt.module_metadata {
|
||||
// expected-error @+1 {{must agree on number of outputs}}
|
||||
refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
||||
}
|
||||
func @f() { return }
|
|
@ -1,19 +0,0 @@
|
|||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK: refbackrt.module_metadata
|
||||
refbackrt.module_metadata {
|
||||
// CHECK: refbackrt.func_metadata
|
||||
// TODO(brycearden): Encode unranked memrefs in the ABI
|
||||
refbackrt.func_metadata {
|
||||
funcName = @f,
|
||||
numInputs = 1 : i32,
|
||||
numOutputs = 0 : i32,
|
||||
inputArgTypes = dense<1> : tensor<1xi32>,
|
||||
inputElementTypes = dense<1> : tensor<1xi32>,
|
||||
inputRanks = dense<-1> : tensor<1xi32>,
|
||||
inputShapes = dense<1> : tensor<4xi32>}
|
||||
}
|
||||
|
||||
func @f(%arg0: tensor<*xf32>) {
|
||||
return
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
// RUN: npcomp-opt -split-input-file -lower-alloc-memref-ops <%s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @basic
|
||||
func @basic(%arg0: tensor<?xindex>) -> memref<?xf32> {
|
||||
// CHECK: %[[I:.*]] = constant 0 : index
|
||||
// CHECK: %[[E:.*]] = tensor.extract %arg0[%[[I]]]
|
||||
// CHECK: alloc(%[[E]])
|
||||
%0 = refback.alloc_memref %arg0 : memref<?xf32>
|
||||
return %0 : memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @all_static
|
||||
func @all_static(%arg0: tensor<?xindex>) -> memref<3x4x5xf32> {
|
||||
// CHECK-NOT: tensor.extract
|
||||
// CHECK: alloc()
|
||||
%0 = refback.alloc_memref %arg0 : memref<3x4x5xf32>
|
||||
return %0 : memref<3x4x5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
// CHECK-LABEL: func @some_static
|
||||
func @some_static(%arg0: tensor<?xindex>) -> memref<3x?x5x?x7xf32> {
|
||||
// CHECK-DAG: %[[I1:.*]] = constant 1 : index
|
||||
// CHECK-DAG: %[[E1:.*]] = tensor.extract %arg0[%[[I1]]]
|
||||
// CHECK-DAG: %[[I3:.*]] = constant 3 : index
|
||||
// CHECK-DAG: %[[E3:.*]] = tensor.extract %arg0[%[[I3]]]
|
||||
// CHECK: alloc(%[[E1]], %[[E3]])
|
||||
%0 = refback.alloc_memref %arg0 : memref<3x?x5x?x7xf32>
|
||||
return %0 : memref<3x?x5x?x7xf32>
|
||||
}
|
|
@ -1,69 +0,0 @@
|
|||
// RUN: npcomp-opt -lower-to-refbackrt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// Test module metadata.
|
||||
|
||||
// CHECK: refbackrt.module_metadata
|
||||
// CHECK-NEXT: refbackrt.func_metadata
|
||||
// CHECK-SAME: funcName = @f_2inputs_0outputs
|
||||
// CHECK-SAME: numInputs = 2
|
||||
// CHECK-SAME: numOutputs = 0
|
||||
// CHECK-NEXT: refbackrt.func_metadata
|
||||
// CHECK-SAME: funcName = @f_1input_2outputs
|
||||
// CHECK-SAME: numInputs = 1
|
||||
// CHECK-SAME: numOutputs = 2
|
||||
|
||||
// This function only exists to test its metadata above.
|
||||
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
||||
return
|
||||
}
|
||||
|
||||
// This function only exists to test its metadata above.
|
||||
func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>) {
|
||||
return %arg0, %arg0 : memref<?xf32>, memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test ABI conversions.
|
||||
|
||||
// CHECK-LABEL: func @identity(%arg0: memref<*xf32>) -> memref<*xf32>
|
||||
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK: return %arg0
|
||||
return %arg0 : memref<?xf32>
|
||||
}
|
||||
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @use_of_arg(%arg0: memref<*xf32>)
|
||||
func @use_of_arg(%arg0: memref<?xf32>) {
|
||||
// CHECK-NEXT: %[[MEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
|
||||
%c0 = constant 0 : index
|
||||
%0 = memref.dim %arg0, %c0 : memref<?xf32>
|
||||
// CHECK-NEXT: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-NEXT: memref.dim %[[MEMREF]], %[[C0]] : memref<?xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @multiple_blocks(%arg0: memref<*xf32>) -> memref<*xf32>
|
||||
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||
// CHECK-NEXT: %[[INMEMREF:.*]] = memref.cast %arg0 : memref<*xf32> to memref<?xf32>
|
||||
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
|
||||
br ^bb1(%arg0: memref<?xf32>)
|
||||
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
|
||||
^bb1(%bbarg: memref<?xf32>):
|
||||
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref.cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
|
||||
// CHECK-NEXT: return %[[OUTMEMREF]] : memref<*xf32>
|
||||
return %bbarg : memref<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test diagnostics.
|
||||
|
||||
// expected-error @+1 {{func not expressible with refbackrt ABI}}
|
||||
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
||||
return
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
# RUN: %PYTHON %s 2>&1
|
||||
|
||||
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
# See https://llvm.org/LICENSE.txt for license information.
|
||||
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
|
||||
from npcomp import build
|
||||
|
||||
assert build.get_include_dirs()
|
||||
assert build.get_lib_dirs()
|
||||
print("CAPI Path:", build.get_capi_link_library_path())
|
|
@ -49,9 +49,6 @@ config.test_source_root = os.path.dirname(__file__)
|
|||
# test_exec_root: The root path where tests should be run.
|
||||
config.test_exec_root = os.path.join(config.npcomp_obj_root, 'test')
|
||||
config.npcomp_bin_dir = os.path.join(config.npcomp_obj_root, 'bin')
|
||||
config.npcomp_runtime_shlib = os.path.join(
|
||||
config.npcomp_obj_root, 'lib',
|
||||
'libNPCOMPCompilerRuntimeShlib' + config.llvm_shlib_ext)
|
||||
|
||||
# Tweak the PATH and PYTHONPATH to include the tools dir.
|
||||
npcomp_python_dir = "python" if config.npcomp_built_standalone else "tools/npcomp/python"
|
||||
|
@ -67,8 +64,6 @@ tool_dirs = [
|
|||
]
|
||||
tools = [
|
||||
'npcomp-opt',
|
||||
'refback-run',
|
||||
ToolSubst('%npcomp_runtime_shlib', config.npcomp_runtime_shlib),
|
||||
]
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
|
|
@ -1,27 +0,0 @@
|
|||
// RUN: refback-run %s \
|
||||
// RUN: -invoke forward \
|
||||
// RUN: -arg-value="dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32>" \
|
||||
// RUN: -arg-value="dense<[10.0, 20.0]> : tensor<2xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK{LITERAL}: output #0: dense<[[1.100000e+01, 2.200000e+01], [1.300000e+01, 2.400000e+01]]> : tensor<2x2xf32>
|
||||
#map0 = affine_map<(d0, d1) -> (d0, d1)>
|
||||
#map1 = affine_map<(d0, d1) -> (d1)>
|
||||
|
||||
builtin.func @forward(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?x?xf32> {
|
||||
%c1 = constant 1 : index
|
||||
%c0 = constant 0 : index
|
||||
%0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
|
||||
%1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
|
||||
%2 = tensor.dim %arg1, %c0 : tensor<?xf32>
|
||||
%3 = cmpi eq, %1, %2 : index
|
||||
assert %3, "mismatched size for broadcast"
|
||||
%4 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
|
||||
%5 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types = ["parallel", "parallel"]} ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>) outs(%4 : tensor<?x?xf32>) {
|
||||
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
|
||||
%6 = addf %arg2, %arg3 : f32
|
||||
linalg.yield %6 : f32
|
||||
} -> tensor<?x?xf32>
|
||||
return %5 : tensor<?x?xf32>
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
// RUN: refback-run %s \
|
||||
// RUN: -invoke constant \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
|
||||
func @constant() -> tensor<f32> {
|
||||
%0 = constant dense<1.0> : tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
// RUN: refback-run %s \
|
||||
// RUN: -invoke identity \
|
||||
// RUN: -arg-value="dense<1.0> : tensor<f32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
|
||||
func @identity(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
return %arg0 : tensor<f32>
|
||||
}
|
|
@ -1,12 +0,0 @@
|
|||
// RUN: not refback-run %s \
|
||||
// RUN: -invoke expects_one_tensor \
|
||||
// RUN: -arg-value="1.0 : f32" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: invoking 'expects_one_tensor': input argument type mismatch.
|
||||
// CHECK-SAME: actual (provided by user): Float
|
||||
// CHECK-SAME: expected (from compiler): kTensor
|
||||
func @expects_one_tensor(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
return %arg0 : tensor<?xf32>
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
// RUN: not refback-run %s \
|
||||
// RUN: -invoke requires_one_input \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: invoking 'requires_one_input': expected 1 inputs
|
||||
func @requires_one_input(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
return %arg0 : tensor<?xf32>
|
||||
}
|
|
@ -1,20 +0,0 @@
|
|||
// RUN: refback-run %s \
|
||||
// RUN: -invoke multi_output \
|
||||
// RUN: -arg-value="dense<1.0> : tensor<1xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: dense<2.000000e+00> : tensor<1xf32>
|
||||
// CHECK: output #1: dense<2.000000e+00> : tensor<1xf32>
|
||||
#map0 = affine_map<(d0) -> (d0)>
|
||||
func @multi_output(%arg0: tensor<?xf32>) -> (tensor<?xf32>, tensor<?xf32>) {
|
||||
%c0 = constant 0 : index
|
||||
%0 = tensor.dim %arg0, %c0 : tensor<?xf32>
|
||||
%1 = linalg.init_tensor [%0] : tensor<?xf32>
|
||||
%2 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} ins(%arg0, %arg0 : tensor<?xf32>, tensor<?xf32>) outs(%1 : tensor<?xf32>) {
|
||||
^bb0(%arg2: f32, %arg3: f32, %arg4: f32): // no predecessors
|
||||
%6 = addf %arg2, %arg3 : f32
|
||||
linalg.yield %6 : f32
|
||||
} -> tensor<?xf32>
|
||||
return %2, %2 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
|
@ -1,10 +0,0 @@
|
|||
// RUN: refback-run %s \
|
||||
// RUN: -invoke scalar_arg \
|
||||
// RUN: -arg-value="2.5 : f32" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: 2.500000e+00 : f32
|
||||
func @scalar_arg(%arg0: f32) -> f32 {
|
||||
return %arg0 : f32
|
||||
}
|
|
@ -1,3 +1,2 @@
|
|||
add_subdirectory(npcomp-lsp-server)
|
||||
add_subdirectory(npcomp-opt)
|
||||
add_subdirectory(refback-run)
|
||||
|
|
|
@ -18,19 +18,6 @@ npcomp-opt() {
|
|||
"$@"
|
||||
}
|
||||
|
||||
refback-run() {
|
||||
# Helper for building and invoking refback-run.
|
||||
#
|
||||
# This also automatically builds and adds the npcomp runtime shared
|
||||
# library.
|
||||
#
|
||||
# Usage:
|
||||
# $ refback-run <regular refback-run options>
|
||||
ninja -C "$build_dir" refback-run NPCOMPCompilerRuntimeShlib 1>&2 || return 1
|
||||
"$build_dir/bin/refback-run" \
|
||||
-shared-libs="${build_dir}/lib/libNPCOMPCompilerRuntimeShlib.so" "$@"
|
||||
}
|
||||
|
||||
# Go to the root of your npcomp checkout.
|
||||
alias npd="cd $td"
|
||||
# Handy so that `npctest -v` runs lit with -v and thus prints out errors,
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
# refback-run is always linked dynamically as we want to distribute the
|
||||
# binaries with the python packages for hacking/debugging.
|
||||
|
||||
get_property(dialect_libs GLOBAL PROPERTY NPCOMP_DIALECT_LIBS)
|
||||
get_property(conversion_libs GLOBAL PROPERTY NPCOMP_CONVERSION_LIBS)
|
||||
|
||||
add_npcomp_executable(refback-run
|
||||
refback-run.cpp
|
||||
)
|
||||
|
||||
llvm_update_compile_flags(refback-run)
|
||||
target_link_libraries(refback-run PRIVATE
|
||||
NPCOMPCAPI
|
||||
NPCOMPInitAll
|
||||
MLIRAnalysis
|
||||
MLIRIR
|
||||
MLIRJitRunner
|
||||
MLIRParser
|
||||
MLIRSupport
|
||||
NPCOMPInitAll
|
||||
NPCOMPRefBackendJITHelpers
|
||||
TorchMLIRInitAll
|
||||
|
||||
# TODO: Remove these in favor of interface deps.
|
||||
${conversion_libs}
|
||||
${dialect_libs}
|
||||
)
|
||||
add_dependencies(refback-run
|
||||
NPCOMPCompilerRuntimeShlib
|
||||
)
|
|
@ -1,252 +0,0 @@
|
|||
//===------------------------------------------------------------*- C++ -*-===//
|
||||
//
|
||||
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Utility binary for compiling and running code through the npcomp
|
||||
// RefBackend. The accepted input is the npcomp backend contract
|
||||
// (roughly, linalg-on-tensors + std).
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/IR/AsmState.h"
|
||||
#include "mlir/InitAllDialects.h"
|
||||
#include "mlir/InitAllPasses.h"
|
||||
#include "mlir/Parser.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "npcomp-c/InitLLVM.h"
|
||||
#include "npcomp/InitAll.h"
|
||||
#include "npcomp/RefBackend/JITHelpers/JITModule.h"
|
||||
#include "torch-mlir/InitAll.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
|
||||
using namespace mlir;
|
||||
using llvm::Error;
|
||||
using llvm::Expected;
|
||||
using llvm::StringError;
|
||||
using llvm::Twine;
|
||||
|
||||
/// Wrap a string into an llvm::StringError.
|
||||
static Error make_string_error(const Twine &message) {
|
||||
return llvm::make_error<StringError>(message.str(),
|
||||
llvm::inconvertibleErrorCode());
|
||||
}
|
||||
|
||||
static Expected<refbackrt::Ref<refbackrt::Tensor>>
|
||||
convertAttrToTensor(Attribute attr) {
|
||||
auto type = attr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return make_string_error("unhandled argument type; must be a tensor type");
|
||||
auto extents = llvm::to_vector<6>(llvm::map_range(
|
||||
type.getShape(), [](int64_t x) { return static_cast<std::int32_t>(x); }));
|
||||
auto elementType = type.getElementType();
|
||||
auto denseFp = attr.dyn_cast<DenseFPElementsAttr>();
|
||||
if (denseFp) {
|
||||
if (elementType.isF32()) {
|
||||
auto values = llvm::to_vector<100>(llvm::map_range(
|
||||
denseFp, [](APFloat f) { return f.convertToFloat(); }));
|
||||
return refbackrt::Tensor::create(
|
||||
refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
||||
refbackrt::ElementType::F32, static_cast<void *>(values.data()));
|
||||
}
|
||||
} else {
|
||||
return make_string_error("unhandled argument; must be dense floating-point");
|
||||
}
|
||||
return make_string_error("unhandled argument");
|
||||
}
|
||||
|
||||
static Expected<float> convertAttrToFloat(Attribute attr) {
|
||||
auto type = attr.getType().dyn_cast<FloatType>();
|
||||
if (!type)
|
||||
return make_string_error("converting an argument to float that is not a FloatType");
|
||||
auto floatAttr = attr.dyn_cast<FloatAttr>();
|
||||
return floatAttr.getValue().convertToFloat();
|
||||
}
|
||||
|
||||
static Expected<SmallVector<refbackrt::RtValue, 6>>
|
||||
createInputs(ArrayRef<StringRef> argValues) {
|
||||
MLIRContext context;
|
||||
SmallVector<refbackrt::RtValue, 6> ret;
|
||||
for (auto argValue : argValues) {
|
||||
auto attr = parseAttribute(argValue, &context);
|
||||
if (!attr)
|
||||
return make_string_error(Twine("could not parse arg value: ") + argValue);
|
||||
|
||||
auto attrType = attr.getType();
|
||||
|
||||
if (attrType.isa<RankedTensorType>()) {
|
||||
auto expectedTensor = convertAttrToTensor(attr);
|
||||
if (!expectedTensor)
|
||||
return expectedTensor.takeError();
|
||||
ret.push_back(std::move(*expectedTensor));
|
||||
} else if (attrType.isa<FloatType>()) {
|
||||
auto expectedFloat = convertAttrToFloat(attr);
|
||||
if (!expectedFloat)
|
||||
return expectedFloat.takeError();
|
||||
ret.push_back(refbackrt::RtValue(*expectedFloat));
|
||||
}
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
static Type convertToMLIRType(refbackrt::ElementType type, Builder &builder) {
|
||||
switch (type) {
|
||||
case refbackrt::ElementType::F32:
|
||||
return builder.getF32Type();
|
||||
default:
|
||||
llvm_unreachable("unsupported dtype");
|
||||
}
|
||||
}
|
||||
|
||||
static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor,
|
||||
Builder &builder) {
|
||||
auto elementType = convertToMLIRType(tensor.getElementType(), builder);
|
||||
SmallVector<int64_t, 6> extents;
|
||||
for (int i = 0, e = tensor.getRank(); i < e; i++)
|
||||
extents.push_back(tensor.getExtent(i));
|
||||
return RankedTensorType::get(extents, elementType);
|
||||
}
|
||||
|
||||
static Attribute convertToMLIRAttribute(const refbackrt::RtValue &value,
|
||||
Builder &builder) {
|
||||
if (value.isTensor()) {
|
||||
auto& tensor = *(value.toTensor());
|
||||
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
||||
switch (tensor.getElementType()) {
|
||||
case refbackrt::ElementType::F32: {
|
||||
SmallVector<float, 100> values;
|
||||
auto *basePtr = tensor.getData<float>();
|
||||
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
||||
values.push_back(basePtr[i]);
|
||||
return DenseFPElementsAttr::get(type, values);
|
||||
}
|
||||
default:
|
||||
llvm_unreachable("unsupported element type");
|
||||
}
|
||||
} else if (value.isFloat()) {
|
||||
return builder.getF32FloatAttr(value.toFloat());
|
||||
}
|
||||
llvm_unreachable("unsupported type");
|
||||
}
|
||||
|
||||
static void printOutput(const refbackrt::RtValue &value, llvm::raw_ostream &os) {
|
||||
MLIRContext context;
|
||||
Builder builder(&context);
|
||||
auto attr = convertToMLIRAttribute(value, builder);
|
||||
attr.print(os);
|
||||
}
|
||||
|
||||
static void printOutputs(ArrayRef<refbackrt::RtValue> outputs,
|
||||
llvm::raw_ostream &os) {
|
||||
for (auto output : llvm::enumerate(outputs)) {
|
||||
os << "output #" << output.index() << ": ";
|
||||
printOutput(output.value(), os);
|
||||
os << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
Error compileAndRun(std::string mlirFile, mlir::MLIRContext &context,
|
||||
std::string invokeFunction, ArrayRef<StringRef> argValues,
|
||||
ArrayRef<StringRef> sharedLibs, bool optimize) {
|
||||
OwningModuleRef moduleRef = parseSourceFile(mlirFile, &context);
|
||||
if (!moduleRef)
|
||||
return make_string_error(Twine("could not open ") + mlirFile);
|
||||
|
||||
ModuleOp module = *moduleRef;
|
||||
|
||||
// Compile.
|
||||
PassManager pm(module.getContext(), OpPassManager::Nesting::Implicit);
|
||||
applyPassManagerCLOptions(pm);
|
||||
refback::JITModule::buildBackendCompilationPipeline(pm, optimize);
|
||||
if (failed(pm.run(module))) {
|
||||
return make_string_error(Twine("error compiling to jit backend"));
|
||||
}
|
||||
|
||||
auto expectedJitModule =
|
||||
refback::JITModule::fromCompiledModule(module, sharedLibs);
|
||||
if (!expectedJitModule)
|
||||
return expectedJitModule.takeError();
|
||||
auto jitModule = std::move(*expectedJitModule);
|
||||
|
||||
auto expectedInputs = createInputs(argValues);
|
||||
if (!expectedInputs)
|
||||
return expectedInputs.takeError();
|
||||
|
||||
auto expectedOutputs = jitModule->invoke(invokeFunction, *expectedInputs);
|
||||
if (!expectedOutputs)
|
||||
return expectedOutputs.takeError();
|
||||
|
||||
auto outputs = std::move(*expectedOutputs);
|
||||
printOutputs(outputs, llvm::outs());
|
||||
llvm::outs() << "SUCCESS\n";
|
||||
return Error::success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Main-related init and option parsing.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
namespace cl = llvm::cl;
|
||||
struct Options {
|
||||
cl::opt<std::string> inputFile{
|
||||
cl::Positional, cl::desc("the input .mlir file"), cl::init("-")};
|
||||
cl::opt<std::string> invokeFunction{"invoke", cl::Required,
|
||||
cl::desc("function to invoke")};
|
||||
cl::list<std::string> argValues{"arg-value", cl::ZeroOrMore,
|
||||
cl::desc("Arguments to the called function")};
|
||||
|
||||
cl::list<std::string> sharedLibs{"shared-libs", cl::ZeroOrMore,
|
||||
cl::MiscFlags::CommaSeparated,
|
||||
cl::desc("Libraries to link dynamically")};
|
||||
cl::opt<bool> optimize{
|
||||
"optimize", cl::Optional,
|
||||
cl::desc("whether the refback pass pipeline should run optimizations"),
|
||||
cl::init(false)};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
mlir::registerMLIRContextCLOptions();
|
||||
mlir::registerAsmPrinterCLOptions();
|
||||
mlir::registerPassManagerCLOptions();
|
||||
Options options;
|
||||
llvm::cl::ParseCommandLineOptions(argc, argv, "npcomp compile+run utility\n");
|
||||
|
||||
mlir::DialectRegistry registry;
|
||||
mlir::registerAllDialects(registry);
|
||||
mlir::registerAllPasses();
|
||||
mlir::NPCOMP::registerAllDialects(registry);
|
||||
mlir::NPCOMP::registerAllPasses();
|
||||
mlir::torch::registerAllDialects(registry);
|
||||
mlir::torch::registerAllPasses();
|
||||
MLIRContext context;
|
||||
context.appendDialectRegistry(registry);
|
||||
context.loadAllAvailableDialects();
|
||||
|
||||
llvm::InitLLVM y(argc, argv);
|
||||
npcompInitializeLLVMCodegen();
|
||||
|
||||
SmallVector<StringRef, 6> sharedLibs(options.sharedLibs.begin(),
|
||||
options.sharedLibs.end());
|
||||
SmallVector<StringRef, 6> argValues(options.argValues.begin(),
|
||||
options.argValues.end());
|
||||
Error error =
|
||||
compileAndRun(options.inputFile, context, options.invokeFunction,
|
||||
argValues, sharedLibs, options.optimize);
|
||||
|
||||
int exitCode = EXIT_SUCCESS;
|
||||
llvm::handleAllErrors(std::move(error),
|
||||
[&exitCode](const llvm::ErrorInfoBase &info) {
|
||||
llvm::errs() << "Error: ";
|
||||
info.log(llvm::errs());
|
||||
llvm::errs() << '\n';
|
||||
exitCode = EXIT_FAILURE;
|
||||
});
|
||||
return exitCode;
|
||||
}
|
Loading…
Reference in New Issue