mirror of https://github.com/llvm/torch-mlir
[RefBackend] Rename Npcomprt dialect to Refbackrt.
parent
83ad70ef54
commit
bf99a82832
|
@ -1,8 +1,8 @@
|
||||||
add_subdirectory(ATen)
|
add_subdirectory(ATen)
|
||||||
add_subdirectory(Basicpy)
|
add_subdirectory(Basicpy)
|
||||||
add_subdirectory(Npcomprt)
|
|
||||||
add_subdirectory(Numpy)
|
add_subdirectory(Numpy)
|
||||||
add_subdirectory(RefBackend)
|
add_subdirectory(RefBackend)
|
||||||
|
add_subdirectory(Refbackrt)
|
||||||
add_subdirectory(TCF)
|
add_subdirectory(TCF)
|
||||||
add_subdirectory(TCP)
|
add_subdirectory(TCP)
|
||||||
add_subdirectory(Torch)
|
add_subdirectory(Torch)
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
add_mlir_dialect(NpcomprtOps npcomprt)
|
|
|
@ -0,0 +1 @@
|
||||||
|
add_mlir_dialect(RefbackrtOps refbackrt)
|
|
@ -6,31 +6,31 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMPRT_BASE
|
#ifndef REFBACKRT_BASE
|
||||||
#define NPCOMPRT_BASE
|
#define REFBACKRT_BASE
|
||||||
|
|
||||||
include "mlir/IR/OpBase.td"
|
include "mlir/IR/OpBase.td"
|
||||||
|
|
||||||
def Npcomprt_Dialect : Dialect {
|
def Refbackrt_Dialect : Dialect {
|
||||||
let name = "npcomprt";
|
let name = "refbackrt";
|
||||||
let cppNamespace = "::mlir::NPCOMP::npcomprt";
|
let cppNamespace = "::mlir::NPCOMP::refbackrt";
|
||||||
let description = [{
|
let description = [{
|
||||||
The `npcomprt` dialect is the IR manifestation for interaction with the
|
The `refbackrt` dialect is the IR manifestation for interaction with the
|
||||||
npcomp runtime. It primarily serves as a layer that enapsulates the data
|
reference backend runtime. It primarily serves as a layer that enapsulates the
|
||||||
structures and functions available in the runtime, and faciliates
|
data structures and functions available in the runtime, and faciliates
|
||||||
conversion to those conventions, such as by providing utilities for being
|
conversion to those conventions, such as by providing utilities for being
|
||||||
lowered to the llvm dialect.
|
lowered to the llvm dialect.
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_Tensor
|
def Refbackrt_Tensor
|
||||||
: DialectType<
|
: DialectType<
|
||||||
Npcomprt_Dialect,
|
Refbackrt_Dialect,
|
||||||
CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">,
|
CPred<"$_self.isa<::mlir::NPCOMP::refbackrt::TensorType>()">,
|
||||||
"npcomprt.tensor">,
|
"refbackrt.tensor">,
|
||||||
BuildableType<
|
BuildableType<
|
||||||
"$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> {
|
"$_builder.getType<::mlir::NPCOMP::refbackrt::TensorType>()"> {
|
||||||
let typeDescription = [{The runtime type that represents a buffer.}];
|
let typeDescription = [{The runtime type that represents a buffer.}];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // #ifndef NPCOMPRT_BASE
|
#endif // #ifndef REFBACKRT_BASE
|
|
@ -6,14 +6,14 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTDIALECT_H
|
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
|
||||||
#define NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTDIALECT_H
|
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
|
||||||
|
|
||||||
#include "mlir/IR/Dialect.h"
|
#include "mlir/IR/Dialect.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
namespace npcomprt {
|
namespace refbackrt {
|
||||||
|
|
||||||
class TensorType : public Type::TypeBase<TensorType, Type, TypeStorage> {
|
class TensorType : public Type::TypeBase<TensorType, Type, TypeStorage> {
|
||||||
public:
|
public:
|
||||||
|
@ -22,10 +22,10 @@ public:
|
||||||
static TensorType get(MLIRContext *context) { return Base::get(context); }
|
static TensorType get(MLIRContext *context) { return Base::get(context); }
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace npcomprt
|
} // namespace refbackrt
|
||||||
} // namespace NPCOMP
|
} // namespace NPCOMP
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOpsDialect.h.inc"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOpsDialect.h.inc"
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTDIALECT_H
|
#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
|
|
@ -6,8 +6,8 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTOPS_H
|
#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
|
||||||
#define NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTOPS_H
|
#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
|
||||||
|
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
|
@ -15,6 +15,6 @@
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h.inc"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h.inc"
|
||||||
|
|
||||||
#endif // NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTOPS_H
|
#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
|
|
@ -6,37 +6,37 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#ifndef NPCOMPRT_OPS
|
#ifndef REFBACKRT_OPS
|
||||||
#define NPCOMPRT_OPS
|
#define REFBACKRT_OPS
|
||||||
|
|
||||||
include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td"
|
include "npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td"
|
||||||
include "mlir/IR/SymbolInterfaces.td"
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []>
|
class Refbackrt_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<Npcomprt_Dialect, mnemonic, traits> {
|
: Op<Refbackrt_Dialect, mnemonic, traits> {
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_ToMemrefOp : Npcomprt_Op<"to_memref"> {
|
def Refbackrt_ToMemrefOp : Refbackrt_Op<"to_memref"> {
|
||||||
let summary = "Gets a memref descriptor from a tensor";
|
let summary = "Gets a memref descriptor from a tensor";
|
||||||
let description = [{
|
let description = [{
|
||||||
Gets a memref descriptor from a tensor.
|
Gets a memref descriptor from a tensor.
|
||||||
}];
|
}];
|
||||||
let arguments = (ins Npcomprt_Tensor:$tensor);
|
let arguments = (ins Refbackrt_Tensor:$tensor);
|
||||||
let results = (outs AnyUnrankedMemRef:$memref);
|
let results = (outs AnyUnrankedMemRef:$memref);
|
||||||
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
|
let assemblyFormat = "$tensor attr-dict `:` type($memref)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_FromMemrefOp : Npcomprt_Op<"from_memref"> {
|
def Refbackrt_FromMemrefOp : Refbackrt_Op<"from_memref"> {
|
||||||
let summary = "Converts a memref descriptor to a tensor";
|
let summary = "Converts a memref descriptor to a tensor";
|
||||||
let description = [{
|
let description = [{
|
||||||
Copies the data from a memref into a new tensor.
|
Copies the data from a memref into a new tensor.
|
||||||
}];
|
}];
|
||||||
let arguments = (ins AnyUnrankedMemRef:$memref);
|
let arguments = (ins AnyUnrankedMemRef:$memref);
|
||||||
let results = (outs Npcomprt_Tensor:$tensor);
|
let results = (outs Refbackrt_Tensor:$tensor);
|
||||||
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
let assemblyFormat = "$memref attr-dict `:` type($memref)";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
|
def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
|
||||||
let summary = "Aborts if the predicate is true";
|
let summary = "Aborts if the predicate is true";
|
||||||
let description = [{
|
let description = [{
|
||||||
Aborts if the predicate is true.
|
Aborts if the predicate is true.
|
||||||
|
@ -46,7 +46,7 @@ def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
|
||||||
let assemblyFormat = "$pred `,` $msg attr-dict";
|
let assemblyFormat = "$pred `,` $msg attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
|
def Refbackrt_GlobalOp : Refbackrt_Op<"global", [Symbol]> {
|
||||||
let summary = "Represents a global variable";
|
let summary = "Represents a global variable";
|
||||||
let description = [{
|
let description = [{
|
||||||
Represents a global variable.
|
Represents a global variable.
|
||||||
|
@ -61,19 +61,19 @@ def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
|
def Refbackrt_GetGlobalOp : Refbackrt_Op<"get_global"> {
|
||||||
let summary = "Obtain a rank-erased memref pointing at the given global";
|
let summary = "Obtain a rank-erased memref pointing at the given global";
|
||||||
let description = [{
|
let description = [{
|
||||||
Obtain a rank-erased memref pointing at the given global.
|
Obtain a rank-erased memref pointing at the given global.
|
||||||
|
|
||||||
TODO: As we define the runtime layer better, we should have fewer
|
TODO: As we define the runtime layer better, we should have fewer
|
||||||
entry points that return memrefs, or at least have a clearer separation
|
entry points that return memrefs, or at least have a clearer separation
|
||||||
between the "memref world" and the "npcomprt world".
|
between the "memref world" and the "refbackrt world".
|
||||||
Something like forming IREE dispatch regions seems to be the missing thing:
|
Something like forming IREE dispatch regions seems to be the missing thing:
|
||||||
- Everything inside the dispatch regions gets things marshaled from the
|
- Everything inside the dispatch regions gets things marshaled from the
|
||||||
runtime (flow/hal/npcomprt) layer to/from memrefs in a clear way.
|
runtime (flow/hal/refbackrt) layer to/from memrefs in a clear way.
|
||||||
- Everything outside the dispatch regions purely uses the runtime
|
- Everything outside the dispatch regions purely uses the runtime
|
||||||
(flow/hal/npcomprt) data structures.
|
(flow/hal/refbackrt) data structures.
|
||||||
Globals should be one of the things that are purely runtime data structures,
|
Globals should be one of the things that are purely runtime data structures,
|
||||||
rather than using memrefs. For now, using memrefs is simpler though.
|
rather than using memrefs. For now, using memrefs is simpler though.
|
||||||
}];
|
}];
|
||||||
|
@ -83,12 +83,12 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
|
||||||
let verifier = "return ::verify$cppClass(*this);";
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [
|
||||||
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
||||||
]> {
|
]> {
|
||||||
let summary = "Global metadata for the module";
|
let summary = "Global metadata for the module";
|
||||||
let description = [{
|
let description = [{
|
||||||
This op contains a region containing npcomprt.func_metadata ops,
|
This op contains a region containing refbackrt.func_metadata ops,
|
||||||
which give information about the functions in the module. This allows
|
which give information about the functions in the module. This allows
|
||||||
the module to be introspected when it is loaded, such as looking up
|
the module to be introspected when it is loaded, such as looking up
|
||||||
functions.
|
functions.
|
||||||
|
@ -110,8 +110,8 @@ def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_ModuleMetadataTerminatorOp
|
def Refbackrt_ModuleMetadataTerminatorOp
|
||||||
: Npcomprt_Op<"module_metadata_terminator",
|
: Refbackrt_Op<"module_metadata_terminator",
|
||||||
[Terminator, HasParent<"ModuleMetadataOp">]> {
|
[Terminator, HasParent<"ModuleMetadataOp">]> {
|
||||||
let summary = "Implicit terminator for ModuleMetadataOp's region";
|
let summary = "Implicit terminator for ModuleMetadataOp's region";
|
||||||
let arguments = (ins);
|
let arguments = (ins);
|
||||||
|
@ -119,8 +119,8 @@ def Npcomprt_ModuleMetadataTerminatorOp
|
||||||
let assemblyFormat = "attr-dict";
|
let assemblyFormat = "attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
def Npcomprt_FuncMetadataOp
|
def Refbackrt_FuncMetadataOp
|
||||||
: Npcomprt_Op<"func_metadata", [HasParent<"ModuleMetadataOp">]> {
|
: Refbackrt_Op<"func_metadata", [HasParent<"ModuleMetadataOp">]> {
|
||||||
let summary = "Runtime metadata for a single func";
|
let summary = "Runtime metadata for a single func";
|
||||||
let description = [{
|
let description = [{
|
||||||
Runtime metadata for a single func.
|
Runtime metadata for a single func.
|
||||||
|
@ -138,4 +138,4 @@ def Npcomprt_FuncMetadataOp
|
||||||
let verifier = [{ return ::verify(*this); }];
|
let verifier = [{ return ::verify(*this); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // #ifndef NPCOMPRT_OPS
|
#endif // #ifndef REFBACKRT_OPS
|
|
@ -24,7 +24,7 @@ class PassManager;
|
||||||
} // namespace mlir
|
} // namespace mlir
|
||||||
|
|
||||||
namespace npcomp {
|
namespace npcomp {
|
||||||
// Wrapper around npcomprt data structures and a JITted module, facilitating
|
// Wrapper around refbackrt data structures and a JITted module, facilitating
|
||||||
// interaction.
|
// interaction.
|
||||||
class JITModule {
|
class JITModule {
|
||||||
public:
|
public:
|
||||||
|
@ -40,14 +40,14 @@ public:
|
||||||
fromCompiledModule(mlir::ModuleOp module,
|
fromCompiledModule(mlir::ModuleOp module,
|
||||||
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
llvm::ArrayRef<llvm::StringRef> sharedLibs);
|
||||||
|
|
||||||
llvm::Expected<llvm::SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
|
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
||||||
invoke(llvm::StringRef functionName,
|
invoke(llvm::StringRef functionName,
|
||||||
llvm::ArrayRef<npcomprt::Ref<npcomprt::Tensor>> inputs);
|
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
JITModule();
|
JITModule();
|
||||||
std::unique_ptr<mlir::ExecutionEngine> engine;
|
std::unique_ptr<mlir::ExecutionEngine> engine;
|
||||||
npcomprt::ModuleDescriptor *descriptor;
|
refbackrt::ModuleDescriptor *descriptor;
|
||||||
};
|
};
|
||||||
} // namespace npcomp
|
} // namespace npcomp
|
||||||
|
|
||||||
|
|
|
@ -52,10 +52,10 @@ def LowerStructuralToMemref :
|
||||||
let constructor = "mlir::NPCOMP::createLowerStructuralToMemrefPass()";
|
let constructor = "mlir::NPCOMP::createLowerStructuralToMemrefPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerToNpcomprtABI : Pass<"lower-to-npcomprt-abi", "ModuleOp"> {
|
def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
|
||||||
let summary = "Lower constructs requiring runtime support to `npcomprt`";
|
let summary = "Lower constructs requiring runtime support to `refbackrt`";
|
||||||
let description = [{
|
let description = [{
|
||||||
We have a specialized dialect `npcomprt` which models our runtime's data
|
We have a specialized dialect `refbackrt` which models our runtime's data
|
||||||
structures, and function signatures (and presumably eventually, other
|
structures, and function signatures (and presumably eventually, other
|
||||||
ABI boundaries like external calls if we ever support it) will be
|
ABI boundaries like external calls if we ever support it) will be
|
||||||
converted.
|
converted.
|
||||||
|
@ -65,7 +65,7 @@ def LowerToNpcomprtABI : Pass<"lower-to-npcomprt-abi", "ModuleOp"> {
|
||||||
- globals
|
- globals
|
||||||
- error handling
|
- error handling
|
||||||
}];
|
}];
|
||||||
let constructor = "mlir::NPCOMP::createLowerToNpcomprtABIPass()";
|
let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
|
def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
|
||||||
|
|
|
@ -34,7 +34,7 @@ createLowerConstantTensorsToMemrefPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>> createLowerToNpcomprtABIPass();
|
std::unique_ptr<OperationPass<ModuleOp>> createLowerToRefbackrtABIPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
RefBackendRt (namespace `refbackrt`) is the runtime support library for the
|
Refbackrt (namespace `refbackrt`) is the runtime support library for the
|
||||||
RefBackend backend. It is best practice to keep compiler and runtime code
|
RefBackend backend. It is best practice to keep compiler and runtime code
|
||||||
totally firewalled.
|
totally firewalled.
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
namespace npcomprt {
|
namespace refbackrt {
|
||||||
class StringRef {
|
class StringRef {
|
||||||
public:
|
public:
|
||||||
StringRef(const char *ptr, std::size_t length) : ptr(ptr), length(length){};
|
StringRef(const char *ptr, std::size_t length) : ptr(ptr), length(length){};
|
||||||
|
@ -94,6 +94,6 @@ inline bool failed(LogicalResult result) {
|
||||||
return result.value == LogicalResult::Failure;
|
return result.value == LogicalResult::Failure;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace npcomprt
|
} // namespace refbackrt
|
||||||
|
|
||||||
#endif // NPCOMP_RUNTIME_SUPPORT_H
|
#endif // NPCOMP_RUNTIME_SUPPORT_H
|
||||||
|
|
|
@ -25,7 +25,7 @@
|
||||||
#include <atomic>
|
#include <atomic>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
|
|
||||||
namespace npcomprt {
|
namespace refbackrt {
|
||||||
|
|
||||||
// Reference-counted handle to a type with a `refCount` member.
|
// Reference-counted handle to a type with a `refCount` member.
|
||||||
template <typename T> class Ref {
|
template <typename T> class Ref {
|
||||||
|
@ -178,6 +178,6 @@ LogicalResult getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||||
StringRef functionName,
|
StringRef functionName,
|
||||||
FunctionMetadata &outMetadata);
|
FunctionMetadata &outMetadata);
|
||||||
|
|
||||||
} // namespace npcomprt
|
} // namespace refbackrt
|
||||||
|
|
||||||
#endif // NPCOMP_RUNTIME_USERAPI_H
|
#endif // NPCOMP_RUNTIME_USERAPI_H
|
||||||
|
|
|
@ -22,8 +22,8 @@ using llvm::Twine;
|
||||||
using mlir::PyModuleOp;
|
using mlir::PyModuleOp;
|
||||||
using mlir::PyPassManager;
|
using mlir::PyPassManager;
|
||||||
using npcomp::JITModule;
|
using npcomp::JITModule;
|
||||||
using npcomprt::Ref;
|
using refbackrt::Ref;
|
||||||
using npcomprt::Tensor;
|
using refbackrt::Tensor;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
|
static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
|
||||||
|
@ -37,10 +37,10 @@ static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
|
||||||
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
|
throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
static npcomprt::ElementType
|
static refbackrt::ElementType
|
||||||
mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
|
mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
|
||||||
if (format == "f")
|
if (format == "f")
|
||||||
return npcomprt::ElementType::F32;
|
return refbackrt::ElementType::F32;
|
||||||
|
|
||||||
std::string message("unsupported buffer format: ");
|
std::string message("unsupported buffer format: ");
|
||||||
message.append(format);
|
message.append(format);
|
||||||
|
@ -61,7 +61,7 @@ static Ref<Tensor> copyBufferToTensor(py::buffer buffer) {
|
||||||
// TODO: Switch Tensor extents to ssize_t for efficiency.
|
// TODO: Switch Tensor extents to ssize_t for efficiency.
|
||||||
SmallVector<std::int32_t, 4> extents(info.shape.begin(), info.shape.end());
|
SmallVector<std::int32_t, 4> extents(info.shape.begin(), info.shape.end());
|
||||||
return Tensor::create(
|
return Tensor::create(
|
||||||
npcomprt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
||||||
elementType, info.ptr);
|
elementType, info.ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ py::array wrapTensorAsArray(Ref<Tensor> tensor) {
|
||||||
|
|
||||||
const char *format;
|
const char *format;
|
||||||
switch (tensor->getElementType()) {
|
switch (tensor->getElementType()) {
|
||||||
case npcomprt::ElementType::F32:
|
case refbackrt::ElementType::F32:
|
||||||
format = "f";
|
format = "f";
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
|
|
@ -43,7 +43,7 @@ add_mlir_library(NPCOMPInitAll
|
||||||
NPCOMPTCPDialect
|
NPCOMPTCPDialect
|
||||||
NPCOMPTCFDialect
|
NPCOMPTCFDialect
|
||||||
NPCOMPTorchDialect
|
NPCOMPTorchDialect
|
||||||
NPCOMPNpcomprtDialect
|
NPCOMPRefbackrtDialect
|
||||||
NPCOMPATenDialect
|
NPCOMPATenDialect
|
||||||
NPCOMPBasicpyDialect
|
NPCOMPBasicpyDialect
|
||||||
NPCOMPBasicpyPasses
|
NPCOMPBasicpyPasses
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
add_subdirectory(ATen)
|
add_subdirectory(ATen)
|
||||||
add_subdirectory(Basicpy)
|
add_subdirectory(Basicpy)
|
||||||
add_subdirectory(Npcomprt)
|
|
||||||
add_subdirectory(Numpy)
|
add_subdirectory(Numpy)
|
||||||
add_subdirectory(RefBackend)
|
add_subdirectory(RefBackend)
|
||||||
|
add_subdirectory(Refbackrt)
|
||||||
add_subdirectory(TCF)
|
add_subdirectory(TCF)
|
||||||
add_subdirectory(TCP)
|
add_subdirectory(TCP)
|
||||||
add_subdirectory(Torch)
|
add_subdirectory(Torch)
|
||||||
|
|
|
@ -1,17 +0,0 @@
|
||||||
add_mlir_dialect_library(NPCOMPNpcomprtDialect
|
|
||||||
NpcomprtDialect.cpp
|
|
||||||
NpcomprtOps.cpp
|
|
||||||
|
|
||||||
ADDITIONAL_HEADER_DIRS
|
|
||||||
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Npcomprt
|
|
||||||
|
|
||||||
DEPENDS
|
|
||||||
MLIRNpcomprtOpsIncGen
|
|
||||||
|
|
||||||
LINK_COMPONENTS
|
|
||||||
Core
|
|
||||||
|
|
||||||
LINK_LIBS PUBLIC
|
|
||||||
MLIRIR
|
|
||||||
MLIRSupport
|
|
||||||
)
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
add_mlir_dialect_library(NPCOMPRefbackrtDialect
|
||||||
|
RefbackrtDialect.cpp
|
||||||
|
RefbackrtOps.cpp
|
||||||
|
|
||||||
|
ADDITIONAL_HEADER_DIRS
|
||||||
|
${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Refbackrt
|
||||||
|
|
||||||
|
DEPENDS
|
||||||
|
MLIRRefbackrtOpsIncGen
|
||||||
|
|
||||||
|
LINK_COMPONENTS
|
||||||
|
Core
|
||||||
|
|
||||||
|
LINK_LIBS PUBLIC
|
||||||
|
MLIRIR
|
||||||
|
MLIRSupport
|
||||||
|
)
|
|
@ -6,23 +6,23 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||||
#include "mlir/IR/DialectImplementation.h"
|
#include "mlir/IR/DialectImplementation.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||||
#include "llvm/ADT/TypeSwitch.h"
|
#include "llvm/ADT/TypeSwitch.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::npcomprt;
|
using namespace mlir::NPCOMP::refbackrt;
|
||||||
|
|
||||||
void NpcomprtDialect::initialize() {
|
void RefbackrtDialect::initialize() {
|
||||||
addOperations<
|
addOperations<
|
||||||
#define GET_OP_LIST
|
#define GET_OP_LIST
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.cpp.inc"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"
|
||||||
>();
|
>();
|
||||||
addTypes<TensorType>();
|
addTypes<TensorType>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Type NpcomprtDialect::parseType(DialectAsmParser &parser) const {
|
Type RefbackrtDialect::parseType(DialectAsmParser &parser) const {
|
||||||
StringRef keyword;
|
StringRef keyword;
|
||||||
if (parser.parseKeyword(&keyword))
|
if (parser.parseKeyword(&keyword))
|
||||||
return Type();
|
return Type();
|
||||||
|
@ -30,14 +30,14 @@ Type NpcomprtDialect::parseType(DialectAsmParser &parser) const {
|
||||||
if (keyword == "tensor")
|
if (keyword == "tensor")
|
||||||
return TensorType::get(getContext());
|
return TensorType::get(getContext());
|
||||||
|
|
||||||
parser.emitError(parser.getNameLoc(), "unknown type in 'npcomprt' dialect: ")
|
parser.emitError(parser.getNameLoc(), "unknown type in 'refbackrt' dialect: ")
|
||||||
<< keyword;
|
<< keyword;
|
||||||
return Type();
|
return Type();
|
||||||
}
|
}
|
||||||
|
|
||||||
void NpcomprtDialect::printType(Type type, DialectAsmPrinter &os) const {
|
void RefbackrtDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||||
TypeSwitch<Type>(type)
|
TypeSwitch<Type>(type)
|
||||||
.Case<NPCOMP::npcomprt::TensorType>([&](Type) { os << "tensor"; })
|
.Case<NPCOMP::refbackrt::TensorType>([&](Type) { os << "tensor"; })
|
||||||
.Default(
|
.Default(
|
||||||
[&](Type) { llvm_unreachable("unexpected 'npcomprt' type kind"); });
|
[&](Type) { llvm_unreachable("unexpected 'refbackrt' type kind"); });
|
||||||
}
|
}
|
|
@ -6,22 +6,22 @@
|
||||||
//
|
//
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::npcomprt;
|
using namespace mlir::NPCOMP::refbackrt;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// GlobalOp
|
// GlobalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||||
p << "npcomprt.global ";
|
p << "refbackrt.global ";
|
||||||
p.printSymbolName(op.sym_name());
|
p.printSymbolName(op.sym_name());
|
||||||
p << ' ';
|
p << ' ';
|
||||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||||
|
@ -49,7 +49,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||||
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
|
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
|
||||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||||
if (!global)
|
if (!global)
|
||||||
return op.emitError() << "must reference a valid npcomprt.global";
|
return op.emitError() << "must reference a valid refbackrt.global";
|
||||||
auto globalType = global.value().getType();
|
auto globalType = global.value().getType();
|
||||||
auto resultType = op.getType().cast<ShapedType>();
|
auto resultType = op.getType().cast<ShapedType>();
|
||||||
if (globalType.getElementType() != resultType.getElementType())
|
if (globalType.getElementType() != resultType.getElementType())
|
||||||
|
@ -62,7 +62,7 @@ static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
|
static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
|
||||||
p << "npcomprt.module_metadata";
|
p << "refbackrt.module_metadata";
|
||||||
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
p.printOptionalAttrDictWithKeyword(op.getAttrs());
|
||||||
p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
|
p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
|
||||||
/*printBlockTerminators=*/false);
|
/*printBlockTerminators=*/false);
|
||||||
|
@ -99,4 +99,4 @@ static LogicalResult verify(FuncMetadataOp op) {
|
||||||
}
|
}
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.cpp.inc"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"
|
|
@ -12,7 +12,7 @@
|
||||||
#include "npcomp/Dialect/ATen/ATenPasses.h"
|
#include "npcomp/Dialect/ATen/ATenPasses.h"
|
||||||
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
#include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
|
||||||
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
#include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
#include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
|
||||||
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
#include "npcomp/Dialect/Numpy/Transforms/Passes.h"
|
||||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
#include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
|
||||||
|
@ -73,7 +73,7 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry ®istry) {
|
||||||
registry.insert<mlir::NPCOMP::aten::ATenDialect,
|
registry.insert<mlir::NPCOMP::aten::ATenDialect,
|
||||||
Basicpy::BasicpyDialect,
|
Basicpy::BasicpyDialect,
|
||||||
Numpy::NumpyDialect,
|
Numpy::NumpyDialect,
|
||||||
npcomprt::NpcomprtDialect,
|
refbackrt::RefbackrtDialect,
|
||||||
refback::RefBackendDialect,
|
refback::RefBackendDialect,
|
||||||
tcf::TCFDialect,
|
tcf::TCFDialect,
|
||||||
tcp::TCPDialect,
|
tcp::TCPDialect,
|
||||||
|
|
|
@ -5,7 +5,7 @@ add_mlir_library(NPCOMPRefBackend
|
||||||
BypassShapes.cpp
|
BypassShapes.cpp
|
||||||
RefBackend.cpp
|
RefBackend.cpp
|
||||||
LowerToLLVM.cpp
|
LowerToLLVM.cpp
|
||||||
LowerToNpcomprtABI.cpp
|
LowerToRefbackrtABI.cpp
|
||||||
TensorToMemref/LowerConstantTensorsToMemref.cpp
|
TensorToMemref/LowerConstantTensorsToMemref.cpp
|
||||||
TensorToMemref/LowerShapedResultsToMemref.cpp
|
TensorToMemref/LowerShapedResultsToMemref.cpp
|
||||||
TensorToMemref/LowerStdToMemref.cpp
|
TensorToMemref/LowerStdToMemref.cpp
|
||||||
|
|
|
@ -51,39 +51,40 @@ JITModule::fromCompiledModule(mlir::ModuleOp module,
|
||||||
if (!expectedAddress)
|
if (!expectedAddress)
|
||||||
return expectedAddress.takeError();
|
return expectedAddress.takeError();
|
||||||
ret->descriptor =
|
ret->descriptor =
|
||||||
reinterpret_cast<npcomprt::ModuleDescriptor *>(*expectedAddress);
|
reinterpret_cast<refbackrt::ModuleDescriptor *>(*expectedAddress);
|
||||||
return std::move(ret);
|
return std::move(ret);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Converter for bridging to npcomprt llvm-lookalike data structures.
|
// Converter for bridging to refbackrt llvm-lookalike data structures.
|
||||||
static npcomprt::StringRef toNpcomprt(llvm::StringRef s) {
|
static refbackrt::StringRef toRefbackrt(llvm::StringRef s) {
|
||||||
return npcomprt::StringRef(s.data(), s.size());
|
return refbackrt::StringRef(s.data(), s.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static npcomprt::ArrayRef<T> toNpcomprt(llvm::ArrayRef<T> a) {
|
static refbackrt::ArrayRef<T> toRefbackrt(llvm::ArrayRef<T> a) {
|
||||||
return npcomprt::ArrayRef<T>(a.data(), a.size());
|
return refbackrt::ArrayRef<T>(a.data(), a.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static npcomprt::MutableArrayRef<T> toNpcomprt(llvm::MutableArrayRef<T> a) {
|
static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
|
||||||
return npcomprt::MutableArrayRef<T>(a.data(), a.size());
|
return refbackrt::MutableArrayRef<T>(a.data(), a.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
llvm::Expected<llvm::SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
|
llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
||||||
JITModule::invoke(llvm::StringRef functionName,
|
JITModule::invoke(llvm::StringRef functionName,
|
||||||
llvm::ArrayRef<npcomprt::Ref<npcomprt::Tensor>> inputs) {
|
llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) {
|
||||||
npcomprt::FunctionMetadata metadata;
|
refbackrt::FunctionMetadata metadata;
|
||||||
if (npcomprt::failed(npcomprt::getMetadata(
|
if (refbackrt::failed(refbackrt::getMetadata(
|
||||||
descriptor, toNpcomprt(functionName), metadata)))
|
descriptor, toRefbackrt(functionName), metadata)))
|
||||||
return make_string_error("unknown function: " + Twine(functionName));
|
return make_string_error("unknown function: " + Twine(functionName));
|
||||||
SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> outputs(metadata.numOutputs);
|
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs(
|
||||||
|
metadata.numOutputs);
|
||||||
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
|
||||||
return make_string_error("invoking '" + Twine(functionName) +
|
return make_string_error("invoking '" + Twine(functionName) +
|
||||||
"': expected " + Twine(metadata.numInputs) +
|
"': expected " + Twine(metadata.numInputs) +
|
||||||
" inputs");
|
" inputs");
|
||||||
npcomprt::invoke(
|
refbackrt::invoke(
|
||||||
descriptor, toNpcomprt(functionName), toNpcomprt(inputs),
|
descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
|
||||||
toNpcomprt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
|
toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
|
||||||
return outputs;
|
return outputs;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,8 +15,8 @@
|
||||||
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
#include "mlir/Dialect/StandardOps/Transforms/Passes.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
|
@ -29,7 +29,7 @@ using mlir::LLVM::LLVMType;
|
||||||
// These correspond to the types in CompilerDataStructures.h
|
// These correspond to the types in CompilerDataStructures.h
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Get the LLVMType for npcomprt::FuncDescriptor.
|
// Get the LLVMType for refbackrt::FuncDescriptor.
|
||||||
static LLVMType getFuncDescriptorTy(MLIRContext *context) {
|
static LLVMType getFuncDescriptorTy(MLIRContext *context) {
|
||||||
return LLVMType::getStructTy(context, {
|
return LLVMType::getStructTy(context, {
|
||||||
// Name length.
|
// Name length.
|
||||||
|
@ -45,7 +45,7 @@ static LLVMType getFuncDescriptorTy(MLIRContext *context) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the LLVMType for npcomprt::ModuleDescriptor.
|
// Get the LLVMType for refbackrt::ModuleDescriptor.
|
||||||
static LLVMType getModuleDescriptorTy(MLIRContext *context) {
|
static LLVMType getModuleDescriptorTy(MLIRContext *context) {
|
||||||
return LLVMType::getStructTy(context,
|
return LLVMType::getStructTy(context,
|
||||||
{
|
{
|
||||||
|
@ -56,7 +56,7 @@ static LLVMType getModuleDescriptorTy(MLIRContext *context) {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the LLVMType for npcomprt::GlobalDescriptor.
|
// Get the LLVMType for refbackrt::GlobalDescriptor.
|
||||||
static LLVMType getGlobalDescriptorTy(MLIRContext *context) {
|
static LLVMType getGlobalDescriptorTy(MLIRContext *context) {
|
||||||
return LLVMType::getStructTy(
|
return LLVMType::getStructTy(
|
||||||
// std::int32_t numExtents;
|
// std::int32_t numExtents;
|
||||||
|
@ -97,13 +97,13 @@ namespace {
|
||||||
// FromMemrefOp requires special handling so that the unranked memref descriptor
|
// FromMemrefOp requires special handling so that the unranked memref descriptor
|
||||||
// gets passed as two separate arguments instead of as a struct.
|
// gets passed as two separate arguments instead of as a struct.
|
||||||
class FromMemrefOpCompilerRuntimeLowering
|
class FromMemrefOpCompilerRuntimeLowering
|
||||||
: public OpConversionPattern<npcomprt::FromMemrefOp> {
|
: public OpConversionPattern<refbackrt::FromMemrefOp> {
|
||||||
public:
|
public:
|
||||||
FromMemrefOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
FromMemrefOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||||
: OpConversionPattern<npcomprt::FromMemrefOp>(backingFunc.getContext()),
|
: OpConversionPattern<refbackrt::FromMemrefOp>(backingFunc.getContext()),
|
||||||
backingFunc(backingFunc) {}
|
backingFunc(backingFunc) {}
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(npcomprt::FromMemrefOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refbackrt::FromMemrefOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto structVal = operands[0];
|
auto structVal = operands[0];
|
||||||
Value rank = rewriter.create<LLVM::ExtractValueOp>(
|
Value rank = rewriter.create<LLVM::ExtractValueOp>(
|
||||||
|
@ -124,17 +124,17 @@ public:
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class GetGlobalOpCompilerRuntimeLowering
|
class GetGlobalOpCompilerRuntimeLowering
|
||||||
: public OpConversionPattern<npcomprt::GetGlobalOp> {
|
: public OpConversionPattern<refbackrt::GetGlobalOp> {
|
||||||
public:
|
public:
|
||||||
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||||
: OpConversionPattern<npcomprt::GetGlobalOp>(backingFunc.getContext()),
|
: OpConversionPattern<refbackrt::GetGlobalOp>(backingFunc.getContext()),
|
||||||
backingFunc(backingFunc) {}
|
backingFunc(backingFunc) {}
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(npcomprt::GetGlobalOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refbackrt::GetGlobalOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
// It would be nice if we could use the constructor here that takes just the
|
// It would be nice if we could use the constructor here that takes just the
|
||||||
// global, but keeping track of the converted llvm.mlir.global op that gets
|
// global, but keeping track of the converted llvm.mlir.global op that gets
|
||||||
// created from the npcomprt.global while conversion is going on is a
|
// created from the refbackrt.global while conversion is going on is a
|
||||||
// headache.
|
// headache.
|
||||||
//
|
//
|
||||||
// Instead, we rely on the symbol name being the same and the result type
|
// Instead, we rely on the symbol name being the same and the result type
|
||||||
|
@ -178,15 +178,15 @@ static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class AbortIfOpCompilerRuntimeLowering
|
class AbortIfOpCompilerRuntimeLowering
|
||||||
: public OpConversionPattern<npcomprt::AbortIfOp> {
|
: public OpConversionPattern<refbackrt::AbortIfOp> {
|
||||||
public:
|
public:
|
||||||
AbortIfOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
AbortIfOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||||
: OpConversionPattern<npcomprt::AbortIfOp>(backingFunc.getContext()),
|
: OpConversionPattern<refbackrt::AbortIfOp>(backingFunc.getContext()),
|
||||||
backingFunc(backingFunc) {}
|
backingFunc(backingFunc) {}
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(npcomprt::AbortIfOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refbackrt::AbortIfOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
npcomprt::AbortIfOp::Adaptor adaptor(operands);
|
refbackrt::AbortIfOp::Adaptor adaptor(operands);
|
||||||
auto *context = op.getContext();
|
auto *context = op.getContext();
|
||||||
|
|
||||||
// Create the global string, take its address, and gep to get an `i8*`.
|
// Create the global string, take its address, and gep to get an `i8*`.
|
||||||
|
@ -207,7 +207,7 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
// Create the LLVM runtime function backing the npcomprt op with name `name`
|
// Create the LLVM runtime function backing the refbackrt op with name `name`
|
||||||
// and requiring `type`.
|
// and requiring `type`.
|
||||||
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
||||||
OpBuilder &builder,
|
OpBuilder &builder,
|
||||||
|
@ -242,12 +242,12 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
||||||
|
|
||||||
{
|
{
|
||||||
auto mlirFunctionType = builder.getFunctionType(
|
auto mlirFunctionType = builder.getFunctionType(
|
||||||
{builder.getType<npcomprt::TensorType>()},
|
{builder.getType<refbackrt::TensorType>()},
|
||||||
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
|
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
|
||||||
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
||||||
LLVMFuncOp toMemrefFunc = createCompilerRuntimeFuncDecl(
|
LLVMFuncOp toMemrefFunc = createCompilerRuntimeFuncDecl(
|
||||||
"to_memref", funcTy, builder, module.getLoc());
|
"to_memref", funcTy, builder, module.getLoc());
|
||||||
patterns.insert<TrivialCompilerRuntimeLowering<npcomprt::ToMemrefOp>>(
|
patterns.insert<TrivialCompilerRuntimeLowering<refbackrt::ToMemrefOp>>(
|
||||||
toMemrefFunc);
|
toMemrefFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
||||||
// doesn't know its own dtype.
|
// doesn't know its own dtype.
|
||||||
auto mlirFunctionType = builder.getFunctionType(
|
auto mlirFunctionType = builder.getFunctionType(
|
||||||
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)},
|
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)},
|
||||||
{builder.getType<npcomprt::TensorType>()});
|
{builder.getType<refbackrt::TensorType>()});
|
||||||
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
||||||
LLVMFuncOp fromMemrefFunc = createCompilerRuntimeFuncDecl(
|
LLVMFuncOp fromMemrefFunc = createCompilerRuntimeFuncDecl(
|
||||||
"from_memref", funcTy, builder, module.getLoc());
|
"from_memref", funcTy, builder, module.getLoc());
|
||||||
|
@ -277,24 +277,24 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Lowering for npcomprt.global
|
// Lowering for refbackrt.global
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class LowerNpcomprtGlobalOp : public OpConversionPattern<npcomprt::GlobalOp> {
|
class LowerRefbackrtGlobalOp : public OpConversionPattern<refbackrt::GlobalOp> {
|
||||||
public:
|
public:
|
||||||
explicit LowerNpcomprtGlobalOp(LLVMTypeConverter &typeConverter)
|
explicit LowerRefbackrtGlobalOp(LLVMTypeConverter &typeConverter)
|
||||||
: OpConversionPattern<npcomprt::GlobalOp>(&typeConverter.getContext()),
|
: OpConversionPattern<refbackrt::GlobalOp>(&typeConverter.getContext()),
|
||||||
typeConverter(typeConverter) {}
|
typeConverter(typeConverter) {}
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(npcomprt::GlobalOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refbackrt::GlobalOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto *context = rewriter.getContext();
|
auto *context = rewriter.getContext();
|
||||||
auto globalDescriptorTy = getGlobalDescriptorTy(context);
|
auto globalDescriptorTy = getGlobalDescriptorTy(context);
|
||||||
|
|
||||||
// Create the data buffer.
|
// Create the data buffer.
|
||||||
auto dataBuffer = createGlobalForDenseElementsAttr(
|
auto dataBuffer = createGlobalForDenseElementsAttr(
|
||||||
(Twine("__npcomprt_global_data_buffer_") + op.sym_name()).str(),
|
(Twine("__refbackrt_global_data_buffer_") + op.sym_name()).str(),
|
||||||
op.value().cast<DenseElementsAttr>(), op, rewriter);
|
op.value().cast<DenseElementsAttr>(), op, rewriter);
|
||||||
|
|
||||||
// Create the extents buffer.
|
// Create the extents buffer.
|
||||||
|
@ -302,8 +302,8 @@ public:
|
||||||
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
|
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
|
||||||
[](int64_t i) -> int32_t { return i; })));
|
[](int64_t i) -> int32_t { return i; })));
|
||||||
auto extentsBuffer = createGlobalForDenseElementsAttr(
|
auto extentsBuffer = createGlobalForDenseElementsAttr(
|
||||||
(Twine("__npcomprt_global_extents_") + op.sym_name()).str(), extentsI32,
|
(Twine("__refbackrt_global_extents_") + op.sym_name()).str(),
|
||||||
op, rewriter);
|
extentsI32, op, rewriter);
|
||||||
|
|
||||||
// Create the GlobalDescriptor.
|
// Create the GlobalDescriptor.
|
||||||
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
|
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
|
||||||
|
@ -352,7 +352,7 @@ public:
|
||||||
private:
|
private:
|
||||||
// TODO: It feels like MLIR core should have better utilities for this.
|
// TODO: It feels like MLIR core should have better utilities for this.
|
||||||
LLVM::GlobalOp createGlobalForDenseElementsAttr(
|
LLVM::GlobalOp createGlobalForDenseElementsAttr(
|
||||||
StringRef symbolName, DenseElementsAttr elements, npcomprt::GlobalOp op,
|
StringRef symbolName, DenseElementsAttr elements, refbackrt::GlobalOp op,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto type = elements.getType().cast<ShapedType>();
|
auto type = elements.getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
@ -384,7 +384,7 @@ private:
|
||||||
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
|
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
|
||||||
}
|
}
|
||||||
|
|
||||||
LLVMType getLLVMTypeForShapedType(ShapedType type, npcomprt::GlobalOp op,
|
LLVMType getLLVMTypeForShapedType(ShapedType type, refbackrt::GlobalOp op,
|
||||||
ConversionPatternRewriter &rewriter) const {
|
ConversionPatternRewriter &rewriter) const {
|
||||||
auto llvmType =
|
auto llvmType =
|
||||||
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
|
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
|
||||||
|
@ -425,7 +425,7 @@ private:
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
static LLVM::GlobalOp
|
static LLVM::GlobalOp
|
||||||
createFuncDescriptorArray(ArrayRef<npcomprt::FuncMetadataOp> funcMetadatas,
|
createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
|
||||||
OpBuilder &builder, Location loc) {
|
OpBuilder &builder, Location loc) {
|
||||||
auto llvmI32Ty = LLVMType::getIntNTy(builder.getContext(), 32);
|
auto llvmI32Ty = LLVMType::getIntNTy(builder.getContext(), 32);
|
||||||
|
|
||||||
|
@ -556,14 +556,14 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
class LowerModuleMetadata
|
class LowerModuleMetadata
|
||||||
: public OpConversionPattern<npcomprt::ModuleMetadataOp> {
|
: public OpConversionPattern<refbackrt::ModuleMetadataOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(npcomprt::ModuleMetadataOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refbackrt::ModuleMetadataOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto funcMetadatas =
|
auto funcMetadatas =
|
||||||
llvm::to_vector<6>(op.metadatas().getOps<npcomprt::FuncMetadataOp>());
|
llvm::to_vector<6>(op.metadatas().getOps<refbackrt::FuncMetadataOp>());
|
||||||
auto funcDescriptorArray =
|
auto funcDescriptorArray =
|
||||||
createFuncDescriptorArray(funcMetadatas, rewriter, op.getLoc());
|
createFuncDescriptorArray(funcMetadatas, rewriter, op.getLoc());
|
||||||
auto moduleDescriptor =
|
auto moduleDescriptor =
|
||||||
|
@ -646,7 +646,7 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
|
||||||
// Construct a wrapper function.
|
// Construct a wrapper function.
|
||||||
// For an externally visible function f(T1, T2) -> T3, T4, we create a
|
// For an externally visible function f(T1, T2) -> T3, T4, we create a
|
||||||
// wrapper
|
// wrapper
|
||||||
// __npcomprt_wrapper_f(void **inputs, void ** outputs) {
|
// __refbackrt_wrapper_f(void **inputs, void ** outputs) {
|
||||||
// T3 t3;
|
// T3 t3;
|
||||||
// T4 t4;
|
// T4 t4;
|
||||||
// (t3, t4) = f(*cast<T1*>(inputs[0]), *cast<T2*>(inputs[1]));
|
// (t3, t4) = f(*cast<T1*>(inputs[0]), *cast<T2*>(inputs[1]));
|
||||||
|
@ -664,8 +664,8 @@ static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
|
||||||
auto wrapperTy = LLVMType::getFunctionTy(LLVMType::getVoidTy(context),
|
auto wrapperTy = LLVMType::getFunctionTy(LLVMType::getVoidTy(context),
|
||||||
{voidStarStarTy, voidStarStarTy},
|
{voidStarStarTy, voidStarStarTy},
|
||||||
/*isVarArg=*/false);
|
/*isVarArg=*/false);
|
||||||
constexpr char kNpcomprtWrapperPrefix[] = "__npcomprt_wrapper_";
|
constexpr char kRefbackrtWrapperPrefix[] = "__refbackrt_wrapper_";
|
||||||
auto wrapperName = (Twine(kNpcomprtWrapperPrefix) + func.getName()).str();
|
auto wrapperName = (Twine(kRefbackrtWrapperPrefix) + func.getName()).str();
|
||||||
OpBuilder moduleBuilder(func.getParentRegion());
|
OpBuilder moduleBuilder(func.getParentRegion());
|
||||||
LLVMFuncOp wrapper = moduleBuilder.create<LLVMFuncOp>(
|
LLVMFuncOp wrapper = moduleBuilder.create<LLVMFuncOp>(
|
||||||
func.getLoc(), wrapperName, wrapperTy, LLVM::Linkage::External);
|
func.getLoc(), wrapperName, wrapperTy, LLVM::Linkage::External);
|
||||||
|
@ -693,8 +693,8 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||||
|
|
||||||
LLVMTypeConverter converter(context);
|
LLVMTypeConverter converter(context);
|
||||||
|
|
||||||
// npcomprt::TensorType is passed as a `void*` in the ABI.
|
// refbackrt::TensorType is passed as a `void*` in the ABI.
|
||||||
converter.addConversion([&](npcomprt::TensorType type) {
|
converter.addConversion([&](refbackrt::TensorType type) {
|
||||||
return LLVMType::getInt8PtrTy(context);
|
return LLVMType::getInt8PtrTy(context);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -706,7 +706,7 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
patterns.insert<LowerModuleMetadata>(context);
|
patterns.insert<LowerModuleMetadata>(context);
|
||||||
patterns.insert<LowerNpcomprtGlobalOp>(converter);
|
patterns.insert<LowerRefbackrtGlobalOp>(converter);
|
||||||
|
|
||||||
// TODO: Move these "std to std" legalizations to their own pass if we grow
|
// TODO: Move these "std to std" legalizations to their own pass if we grow
|
||||||
// lots of these patterns.
|
// lots of these patterns.
|
||||||
|
|
|
@ -14,8 +14,8 @@
|
||||||
#include "mlir/IR/Verifier.h"
|
#include "mlir/IR/Verifier.h"
|
||||||
#include "mlir/Transforms/DialectConversion.h"
|
#include "mlir/Transforms/DialectConversion.h"
|
||||||
|
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
|
#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
|
||||||
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
#include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
|
@ -32,10 +32,10 @@ static Type getABIMemrefType(Type type) {
|
||||||
// Creating module metadata.
|
// Creating module metadata.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
// Returns true if the function signature can be expressed with the npcomprt
|
// Returns true if the function signature can be expressed with the refbackrt
|
||||||
// ABI.
|
// ABI.
|
||||||
static bool expressibleWithNpcomprtABI(FunctionType type) {
|
static bool expressibleWithRefbackrtABI(FunctionType type) {
|
||||||
// Currently, only memref types can be exposed at npcomprt ABI boundaries.
|
// Currently, only memref types can be exposed at refbackrt ABI boundaries.
|
||||||
return llvm::all_of(
|
return llvm::all_of(
|
||||||
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
llvm::concat<const Type>(type.getInputs(), type.getResults()),
|
||||||
[](Type t) { return t.isa<MemRefType>(); });
|
[](Type t) { return t.isa<MemRefType>(); });
|
||||||
|
@ -44,11 +44,11 @@ static bool expressibleWithNpcomprtABI(FunctionType type) {
|
||||||
static LogicalResult createModuleMetadata(ModuleOp module) {
|
static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||||
auto moduleMetadata =
|
auto moduleMetadata =
|
||||||
OpBuilder::atBlockBegin(module.getBody())
|
OpBuilder::atBlockBegin(module.getBody())
|
||||||
.create<npcomprt::ModuleMetadataOp>(module.getLoc());
|
.create<refbackrt::ModuleMetadataOp>(module.getLoc());
|
||||||
moduleMetadata.metadatas().push_back(new Block);
|
moduleMetadata.metadatas().push_back(new Block);
|
||||||
Block &metadatas = moduleMetadata.metadatas().front();
|
Block &metadatas = moduleMetadata.metadatas().front();
|
||||||
OpBuilder::atBlockEnd(&metadatas)
|
OpBuilder::atBlockEnd(&metadatas)
|
||||||
.create<npcomprt::ModuleMetadataTerminatorOp>(module.getLoc());
|
.create<refbackrt::ModuleMetadataTerminatorOp>(module.getLoc());
|
||||||
|
|
||||||
SymbolTable symbolTable(module);
|
SymbolTable symbolTable(module);
|
||||||
auto builder = OpBuilder::atBlockBegin(&metadatas);
|
auto builder = OpBuilder::atBlockBegin(&metadatas);
|
||||||
|
@ -59,13 +59,13 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
|
||||||
}
|
}
|
||||||
// TODO: Add richer information here such as expected shapes and element
|
// TODO: Add richer information here such as expected shapes and element
|
||||||
// types.
|
// types.
|
||||||
builder.create<npcomprt::FuncMetadataOp>(
|
builder.create<refbackrt::FuncMetadataOp>(
|
||||||
func.getLoc(), builder.getSymbolRefAttr(func.getName()),
|
func.getLoc(), builder.getSymbolRefAttr(func.getName()),
|
||||||
builder.getI32IntegerAttr(func.getNumArguments()),
|
builder.getI32IntegerAttr(func.getNumArguments()),
|
||||||
builder.getI32IntegerAttr(func.getNumResults()));
|
builder.getI32IntegerAttr(func.getNumResults()));
|
||||||
|
|
||||||
if (!expressibleWithNpcomprtABI(func.getType()))
|
if (!expressibleWithRefbackrtABI(func.getType()))
|
||||||
return func.emitError() << "func not expressible with npcomprt ABI";
|
return func.emitError() << "func not expressible with refbackrt ABI";
|
||||||
}
|
}
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
@ -81,8 +81,8 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
|
rewriter.replaceOpWithNewOp<refbackrt::GlobalOp>(op, op.sym_name(),
|
||||||
op.value());
|
op.value());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -96,7 +96,7 @@ public:
|
||||||
LogicalResult
|
LogicalResult
|
||||||
matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
|
auto abiMemref = rewriter.create<refbackrt::GetGlobalOp>(
|
||||||
op.getLoc(), getABIMemrefType(op.getType()), op.global());
|
op.getLoc(), getABIMemrefType(op.getType()), op.global());
|
||||||
// Cast back to the original type.
|
// Cast back to the original type.
|
||||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
||||||
|
@ -113,22 +113,22 @@ public:
|
||||||
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
|
matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
|
||||||
ConversionPatternRewriter &rewriter) const override {
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
AssertOp::Adaptor adaptor(operands);
|
AssertOp::Adaptor adaptor(operands);
|
||||||
// The npcomprt runtime function aborts if the argument is true, rather than
|
// The refbackrt runtime function aborts if the argument is true, rather
|
||||||
// when it is false as an `assert` does. So negate the predicate (by xor'ing
|
// than when it is false as an `assert` does. So negate the predicate (by
|
||||||
// with 1).
|
// xor'ing with 1).
|
||||||
auto c1 = rewriter.create<ConstantOp>(
|
auto c1 = rewriter.create<ConstantOp>(
|
||||||
op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(),
|
op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(),
|
||||||
APInt(/*numBits=*/1, /*val=*/1)));
|
APInt(/*numBits=*/1, /*val=*/1)));
|
||||||
Value assertFailed = rewriter.create<XOrOp>(op.getLoc(), adaptor.arg(), c1);
|
Value assertFailed = rewriter.create<XOrOp>(op.getLoc(), adaptor.arg(), c1);
|
||||||
rewriter.replaceOpWithNewOp<npcomprt::AbortIfOp>(op, assertFailed,
|
rewriter.replaceOpWithNewOp<refbackrt::AbortIfOp>(op, assertFailed,
|
||||||
op.msgAttr());
|
op.msgAttr());
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// At ABI bondaries, use !npcomprt.tensor instead of memref.
|
// At ABI bondaries, use !refbackrt.tensor instead of memref.
|
||||||
class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
|
||||||
public:
|
public:
|
||||||
using OpConversionPattern::OpConversionPattern;
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
@ -159,7 +159,7 @@ public:
|
||||||
for (auto newAndOldArg :
|
for (auto newAndOldArg :
|
||||||
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
|
llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
|
||||||
std::tie(newArg, oldArg) = newAndOldArg;
|
std::tie(newArg, oldArg) = newAndOldArg;
|
||||||
auto abiMemref = rewriter.create<npcomprt::ToMemrefOp>(
|
auto abiMemref = rewriter.create<refbackrt::ToMemrefOp>(
|
||||||
op.getLoc(), getABIMemrefType(oldArg.getType()), newArg);
|
op.getLoc(), getABIMemrefType(oldArg.getType()), newArg);
|
||||||
auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref,
|
auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref,
|
||||||
oldArg.getType());
|
oldArg.getType());
|
||||||
|
@ -172,7 +172,7 @@ public:
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// At the return ABI boundaries, convert to !npcomprt.tensor type.
|
// At the return ABI boundaries, convert to !refbackrt.tensor type.
|
||||||
// This pattern is needed to trigger the type conversion mechanics to do a
|
// This pattern is needed to trigger the type conversion mechanics to do a
|
||||||
// target materialization.
|
// target materialization.
|
||||||
class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
|
class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
|
||||||
|
@ -193,20 +193,20 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
||||||
TypeConverter typeConverter;
|
TypeConverter typeConverter;
|
||||||
typeConverter.addConversion([](Type type) { return type; });
|
typeConverter.addConversion([](Type type) { return type; });
|
||||||
typeConverter.addConversion([](MemRefType type) {
|
typeConverter.addConversion([](MemRefType type) {
|
||||||
return npcomprt::TensorType::get(type.getContext());
|
return refbackrt::TensorType::get(type.getContext());
|
||||||
});
|
});
|
||||||
typeConverter.addTargetMaterialization(
|
typeConverter.addTargetMaterialization(
|
||||||
[](OpBuilder &builder, npcomprt::TensorType type, ValueRange inputs,
|
[](OpBuilder &builder, refbackrt::TensorType type, ValueRange inputs,
|
||||||
Location loc) -> Value {
|
Location loc) -> Value {
|
||||||
assert(inputs.size() == 1);
|
assert(inputs.size() == 1);
|
||||||
auto abiMemref = builder.create<MemRefCastOp>(
|
auto abiMemref = builder.create<MemRefCastOp>(
|
||||||
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
|
loc, inputs[0], getABIMemrefType(inputs[0].getType()));
|
||||||
return builder.create<npcomprt::FromMemrefOp>(loc, type, abiMemref);
|
return builder.create<refbackrt::FromMemrefOp>(loc, type, abiMemref);
|
||||||
});
|
});
|
||||||
|
|
||||||
OwningRewritePatternList patterns;
|
OwningRewritePatternList patterns;
|
||||||
ConversionTarget target(*context);
|
ConversionTarget target(*context);
|
||||||
target.addLegalDialect<npcomprt::NpcomprtDialect>();
|
target.addLegalDialect<refbackrt::RefbackrtDialect>();
|
||||||
target.addLegalDialect<StandardOpsDialect>();
|
target.addLegalDialect<StandardOpsDialect>();
|
||||||
|
|
||||||
patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
|
patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
|
||||||
|
@ -230,10 +230,11 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// This pass lowers the public ABI of the module to the primitives exposed by
|
// This pass lowers the public ABI of the module to the primitives exposed by
|
||||||
// the npcomprt dialect.
|
// the refbackrt dialect.
|
||||||
class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
class LowerToRefbackrtABI
|
||||||
|
: public LowerToRefbackrtABIBase<LowerToRefbackrtABI> {
|
||||||
void getDependentDialects(DialectRegistry ®istry) const override {
|
void getDependentDialects(DialectRegistry ®istry) const override {
|
||||||
registry.insert<npcomprt::NpcomprtDialect>();
|
registry.insert<refbackrt::RefbackrtDialect>();
|
||||||
}
|
}
|
||||||
|
|
||||||
void runOnOperation() override {
|
void runOnOperation() override {
|
||||||
|
@ -254,6 +255,6 @@ class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<ModuleOp>>
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
mlir::NPCOMP::createLowerToNpcomprtABIPass() {
|
mlir::NPCOMP::createLowerToRefbackrtABIPass() {
|
||||||
return std::make_unique<LowerToNpcomprtABI>();
|
return std::make_unique<LowerToRefbackrtABI>();
|
||||||
}
|
}
|
|
@ -292,12 +292,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
|
||||||
pm.addPass(createLowerToCFGPass());
|
pm.addPass(createLowerToCFGPass());
|
||||||
|
|
||||||
// Convert functions signatures and other constructs that interface with the
|
// Convert functions signatures and other constructs that interface with the
|
||||||
// runtime to the `npcomprt` dialect.
|
// runtime to the `refbackrt` dialect.
|
||||||
pm.addPass(createLowerToNpcomprtABIPass());
|
pm.addPass(createLowerToRefbackrtABIPass());
|
||||||
|
|
||||||
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
|
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass
|
||||||
// which reuses the upstream patterns and gives us a place to add our own
|
// which reuses the upstream patterns and gives us a place to add our own
|
||||||
// patterns for our own custom ops like the npcomprt ops.
|
// patterns for our own custom ops like the refbackrt ops.
|
||||||
pm.addPass(createLowerToLLVMPass());
|
pm.addPass(createLowerToLLVMPass());
|
||||||
|
|
||||||
// Although LLVM will clean everything up eventually, for the sake of IR
|
// Although LLVM will clean everything up eventually, for the sake of IR
|
||||||
|
|
|
@ -7,7 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
|
||||||
)
|
)
|
||||||
|
|
||||||
# The library that users link against, defining basic interactions with an
|
# The library that users link against, defining basic interactions with an
|
||||||
# npcomprt module and the relevant data structures.
|
# refbackrt module and the relevant data structures.
|
||||||
add_mlir_library(NPCOMPRuntime
|
add_mlir_library(NPCOMPRuntime
|
||||||
Runtime.cpp
|
Runtime.cpp
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
|
||||||
namespace npcomprt {
|
namespace refbackrt {
|
||||||
|
|
||||||
// All arguments are packed into this type-erased form for being invoked. See
|
// All arguments are packed into this type-erased form for being invoked. See
|
||||||
// LowerToLLVM.cpp for more details.
|
// LowerToLLVM.cpp for more details.
|
||||||
|
@ -55,6 +55,6 @@ struct GlobalDescriptor {
|
||||||
void *data;
|
void *data;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace npcomprt
|
} // namespace refbackrt
|
||||||
|
|
||||||
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
#include "CompilerDataStructures.h"
|
#include "CompilerDataStructures.h"
|
||||||
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
#include "npcomp/RefBackend/Runtime/UserAPI.h"
|
||||||
|
|
||||||
using namespace npcomprt;
|
using namespace refbackrt;
|
||||||
|
|
||||||
extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
|
extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
|
||||||
if (b) {
|
if (b) {
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
#include "CompilerDataStructures.h"
|
#include "CompilerDataStructures.h"
|
||||||
|
|
||||||
using namespace npcomprt;
|
using namespace refbackrt;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Tensor
|
// Tensor
|
||||||
|
@ -29,7 +29,7 @@ static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::int32_t npcomprt::getElementTypeByteSize(ElementType type) {
|
std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case ElementType::F32:
|
case ElementType::F32:
|
||||||
return 4;
|
return 4;
|
||||||
|
@ -85,9 +85,9 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
void npcomprt::invoke(ModuleDescriptor *moduleDescriptor,
|
void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
StringRef functionName, ArrayRef<Ref<Tensor>> inputs,
|
StringRef functionName, ArrayRef<Ref<Tensor>> inputs,
|
||||||
MutableArrayRef<Ref<Tensor>> outputs) {
|
MutableArrayRef<Ref<Tensor>> outputs) {
|
||||||
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
||||||
assert(descriptor && "unknown function name");
|
assert(descriptor && "unknown function name");
|
||||||
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
|
||||||
|
@ -104,7 +104,7 @@ void npcomprt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
for (int i = 0, e = outputs.size(); i < e; i++)
|
for (int i = 0, e = outputs.size(); i < e; i++)
|
||||||
outputTensorPtrs[i] = static_cast<Tensor *>(packedOutputs[i]);
|
outputTensorPtrs[i] = static_cast<Tensor *>(packedOutputs[i]);
|
||||||
// TODO: Actually manage refcounts inside the compiler.
|
// TODO: Actually manage refcounts inside the compiler.
|
||||||
// Right now, we only pass around npcomprt.tensor's in trivial ways on ABI
|
// Right now, we only pass around refbackrt.tensor's in trivial ways on ABI
|
||||||
// boundaries, so the following contract of the compiler-generated code works:
|
// boundaries, so the following contract of the compiler-generated code works:
|
||||||
// - input tensors are never retained or released
|
// - input tensors are never retained or released
|
||||||
// - output tensors always have refcount 0. Hence the next line here is
|
// - output tensors always have refcount 0. Hence the next line here is
|
||||||
|
@ -113,9 +113,9 @@ void npcomprt::invoke(ModuleDescriptor *moduleDescriptor,
|
||||||
outputs[i] = Ref<Tensor>(outputTensorPtrs[i]);
|
outputs[i] = Ref<Tensor>(outputTensorPtrs[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
LogicalResult npcomprt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
|
||||||
StringRef functionName,
|
StringRef functionName,
|
||||||
FunctionMetadata &outMetadata) {
|
FunctionMetadata &outMetadata) {
|
||||||
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
|
||||||
if (!descriptor)
|
if (!descriptor)
|
||||||
return failure();
|
return failure();
|
||||||
|
|
|
@ -1,43 +0,0 @@
|
||||||
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
|
|
||||||
|
|
||||||
npcomprt.module_metadata {
|
|
||||||
// expected-error @+1 {{must reference a valid func}}
|
|
||||||
npcomprt.func_metadata {funcName = @g, numInputs = 1 : i32, numOutputs = 0 : i32}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
npcomprt.module_metadata {
|
|
||||||
// expected-error @+1 {{must agree on number of inputs}}
|
|
||||||
npcomprt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
|
||||||
}
|
|
||||||
func @f() { return }
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
npcomprt.module_metadata {
|
|
||||||
// expected-error @+1 {{must agree on number of outputs}}
|
|
||||||
npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
|
||||||
}
|
|
||||||
func @f() { return }
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
npcomprt.global @g dense<0.0> : tensor<2xf32>
|
|
||||||
|
|
||||||
func @f() {
|
|
||||||
// expected-error @+1 {{must reference a valid npcomprt.global}}
|
|
||||||
npcomprt.get_global @nonexistent_symbol : memref<*xf32>
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// -----
|
|
||||||
|
|
||||||
npcomprt.global @g dense<0.0> : tensor<2xf32>
|
|
||||||
|
|
||||||
func @f() {
|
|
||||||
// expected-error @+1 {{inconsistent with element type of global}}
|
|
||||||
npcomprt.get_global @g : memref<*xi8>
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,20 +0,0 @@
|
||||||
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
|
|
||||||
|
|
||||||
// CHECK: npcomprt.module_metadata
|
|
||||||
npcomprt.module_metadata {
|
|
||||||
// CHECK: npcomprt.func_metadata
|
|
||||||
npcomprt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: func @f
|
|
||||||
// CHECK-SAME: !npcomprt.tensor
|
|
||||||
func @f(%arg0: !npcomprt.tensor) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// CHECK-LABEL: npcomprt.global @g dense<0.0{{.*}}> : tensor<10xf32>
|
|
||||||
npcomprt.global @g dense<0.0> : tensor<10xf32>
|
|
||||||
func @uses_global() {
|
|
||||||
npcomprt.get_global @g : memref<*xf32>
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
|
||||||
|
|
||||||
|
refbackrt.module_metadata {
|
||||||
|
// expected-error @+1 {{must reference a valid func}}
|
||||||
|
refbackrt.func_metadata {funcName = @g, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
refbackrt.module_metadata {
|
||||||
|
// expected-error @+1 {{must agree on number of inputs}}
|
||||||
|
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||||
|
}
|
||||||
|
func @f() { return }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
refbackrt.module_metadata {
|
||||||
|
// expected-error @+1 {{must agree on number of outputs}}
|
||||||
|
refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
||||||
|
}
|
||||||
|
func @f() { return }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
refbackrt.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{must reference a valid refbackrt.global}}
|
||||||
|
refbackrt.get_global @nonexistent_symbol : memref<*xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
refbackrt.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{inconsistent with element type of global}}
|
||||||
|
refbackrt.get_global @g : memref<*xi8>
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK: refbackrt.module_metadata
|
||||||
|
refbackrt.module_metadata {
|
||||||
|
// CHECK: refbackrt.func_metadata
|
||||||
|
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @f
|
||||||
|
// CHECK-SAME: !refbackrt.tensor
|
||||||
|
func @f(%arg0: !refbackrt.tensor) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: refbackrt.global @g dense<0.0{{.*}}> : tensor<10xf32>
|
||||||
|
refbackrt.global @g dense<0.0> : tensor<10xf32>
|
||||||
|
func @uses_global() {
|
||||||
|
refbackrt.get_global @g : memref<*xf32>
|
||||||
|
return
|
||||||
|
}
|
|
@ -5,17 +5,17 @@
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float>
|
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float>
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32>
|
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32>
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> {
|
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> {
|
||||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomprt_global_extents_g : !llvm.ptr<array<1 x i32>>
|
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__refbackrt_global_extents_g : !llvm.ptr<array<1 x i32>>
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x i32>> to !llvm.ptr<i32>
|
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x i32>> to !llvm.ptr<i32>
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__npcomprt_global_data_buffer_g : !llvm.ptr<array<3 x float>>
|
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__refbackrt_global_data_buffer_g : !llvm.ptr<array<3 x float>>
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr<array<3 x float>> to !llvm.ptr<i8>
|
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr<array<3 x float>> to !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||||
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
|
||||||
|
@ -44,15 +44,15 @@
|
||||||
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
npcomprt.global @g dense<7.000000e+00> : tensor<3xf32>
|
refbackrt.global @g dense<7.000000e+00> : tensor<3xf32>
|
||||||
func @calls_get_global() -> memref<*xf32> {
|
func @calls_get_global() -> memref<*xf32> {
|
||||||
%0 = npcomprt.get_global @g : memref<*xf32>
|
%0 = refbackrt.get_global @g : memref<*xf32>
|
||||||
return %0 : memref<*xf32>
|
return %0 : memref<*xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
|
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm.array<1 x float>
|
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm.array<1 x float>
|
||||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32>
|
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32>
|
||||||
npcomprt.global @g dense<7.0> : tensor<f32>
|
refbackrt.global @g dense<7.0> : tensor<f32>
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_identity(
|
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_identity(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
@ -28,7 +28,7 @@
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_identity : !llvm.ptr<array<8 x i8>>
|
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_identity : !llvm.ptr<array<8 x i8>>
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<8 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<8 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_identity : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_identity : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
@ -48,8 +48,8 @@
|
||||||
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
npcomprt.module_metadata {
|
refbackrt.module_metadata {
|
||||||
npcomprt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32}
|
refbackrt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,15 +57,15 @@ npcomprt.module_metadata {
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
|
||||||
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
|
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
|
||||||
return %arg0 : !npcomprt.tensor
|
return %arg0 : !refbackrt.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// Test input/output arg marshaling.
|
// Test input/output arg marshaling.
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_inputs1results2(
|
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
@ -86,7 +86,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_inputs1results1(
|
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
@ -101,7 +101,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_inputs1results0(
|
// CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
|
||||||
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
|
||||||
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
|
||||||
|
@ -127,7 +127,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
|
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
|
||||||
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||||
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
@ -139,7 +139,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
|
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
|
||||||
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||||
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
@ -151,7 +151,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
|
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
|
||||||
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
|
||||||
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
|
||||||
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
|
||||||
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
@ -171,10 +171,10 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
|
|
||||||
npcomprt.module_metadata {
|
refbackrt.module_metadata {
|
||||||
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
refbackrt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
|
||||||
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
refbackrt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
|
||||||
npcomprt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
|
refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -183,7 +183,7 @@ npcomprt.module_metadata {
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @inputs1results0(%arg0: !npcomprt.tensor) {
|
func @inputs1results0(%arg0: !refbackrt.tensor) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -191,8 +191,8 @@ func @inputs1results0(%arg0: !npcomprt.tensor) {
|
||||||
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
|
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
|
||||||
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
|
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @inputs1results1(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
func @inputs1results1(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
|
||||||
return %arg0 : !npcomprt.tensor
|
return %arg0 : !refbackrt.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
// CHECK-LABEL: llvm.func @inputs1results2(
|
// CHECK-LABEL: llvm.func @inputs1results2(
|
||||||
|
@ -202,8 +202,8 @@ func @inputs1results1(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
|
||||||
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1] : !llvm.struct<(ptr<i8>, ptr<i8>)>
|
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1] : !llvm.struct<(ptr<i8>, ptr<i8>)>
|
||||||
// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr<i8>, ptr<i8>)>
|
// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr<i8>, ptr<i8>)>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @inputs1results2(%arg0: !npcomprt.tensor) -> (!npcomprt.tensor, !npcomprt.tensor) {
|
func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackrt.tensor) {
|
||||||
return %arg0, %arg0 : !npcomprt.tensor, !npcomprt.tensor
|
return %arg0, %arg0 : !refbackrt.tensor, !refbackrt.tensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -226,7 +226,7 @@ func @inputs1results2(%arg0: !npcomprt.tensor) -> (!npcomprt.tensor, !npcomprt.t
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
|
|
||||||
func @calls_abort_if(%arg0: i1) {
|
func @calls_abort_if(%arg0: i1) {
|
||||||
npcomprt.abort_if %arg0, "msg"
|
refbackrt.abort_if %arg0, "msg"
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -235,8 +235,8 @@ func @calls_abort_if(%arg0: i1) {
|
||||||
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_to_memref(%[[VAL_0]]) : (!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_to_memref(%[[VAL_0]]) : (!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
|
||||||
// CHECK: llvm.return
|
// CHECK: llvm.return
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @calls_to_memref(%arg0: !npcomprt.tensor) {
|
func @calls_to_memref(%arg0: !refbackrt.tensor) {
|
||||||
%0 = npcomprt.to_memref %arg0 : memref<*xf32>
|
%0 = refbackrt.to_memref %arg0 : memref<*xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -251,7 +251,7 @@ func @calls_to_memref(%arg0: !npcomprt.tensor) {
|
||||||
// CHECK: %[[VAL_7:.*]] = llvm.call @__npcomp_compiler_rt_from_memref(%[[VAL_5]], %[[VAL_6]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
// CHECK: %[[VAL_7:.*]] = llvm.call @__npcomp_compiler_rt_from_memref(%[[VAL_5]], %[[VAL_6]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
|
||||||
// CHECK: llvm.return %[[VAL_7]] : !llvm.ptr<i8>
|
// CHECK: llvm.return %[[VAL_7]] : !llvm.ptr<i8>
|
||||||
// CHECK: }
|
// CHECK: }
|
||||||
func @calls_from_memref(%arg0: memref<*xf32>) -> !npcomprt.tensor {
|
func @calls_from_memref(%arg0: memref<*xf32>) -> !refbackrt.tensor {
|
||||||
%0 = npcomprt.from_memref %arg0 : memref<*xf32>
|
%0 = refbackrt.from_memref %arg0 : memref<*xf32>
|
||||||
return %0 : !npcomprt.tensor
|
return %0 : !refbackrt.tensor
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
// RUN: npcomp-opt -lower-to-npcomprt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt -lower-to-refbackrt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
// Test module metadata.
|
// Test module metadata.
|
||||||
|
|
||||||
// CHECK: npcomprt.module_metadata
|
// CHECK: refbackrt.module_metadata
|
||||||
// CHECK-NEXT: npcomprt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
|
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
|
||||||
// CHECK-NEXT: npcomprt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
|
// CHECK-NEXT: refbackrt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
|
||||||
|
|
||||||
// This function only exists to test its metadata above.
|
// This function only exists to test its metadata above.
|
||||||
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
|
||||||
|
@ -20,12 +20,12 @@ func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>)
|
||||||
|
|
||||||
// Test ABI conversions.
|
// Test ABI conversions.
|
||||||
|
|
||||||
// CHECK-LABEL: func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor
|
// CHECK-LABEL: func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
|
||||||
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
// The argument materialization.
|
// The argument materialization.
|
||||||
// In this test case, these go unused since, as described below, the new
|
// In this test case, these go unused since, as described below, the new
|
||||||
// argument value is seen immediately by the return op for some reason.
|
// argument value is seen immediately by the return op for some reason.
|
||||||
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32>
|
// CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
|
||||||
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
||||||
|
|
||||||
// TODO: Why do these target materializations not happen in this particular
|
// TODO: Why do these target materializations not happen in this particular
|
||||||
|
@ -34,7 +34,7 @@ func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
// rather than the result of replaceUsesOfBlockArgument from
|
// rather than the result of replaceUsesOfBlockArgument from
|
||||||
// FuncOpSignatureConversion
|
// FuncOpSignatureConversion
|
||||||
// Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32>
|
// Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32>
|
||||||
// Cxxxx-NEXT: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
// Cxxxx-NEXT: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
||||||
// Cxxxx-NEXT: return %[[RET]]
|
// Cxxxx-NEXT: return %[[RET]]
|
||||||
|
|
||||||
// CHECK-NEXT: return %arg0
|
// CHECK-NEXT: return %arg0
|
||||||
|
@ -44,9 +44,9 @@ func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @use_of_arg(%arg0: !npcomprt.tensor)
|
// CHECK-LABEL: func @use_of_arg(%arg0: !refbackrt.tensor)
|
||||||
func @use_of_arg(%arg0: memref<?xf32>) {
|
func @use_of_arg(%arg0: memref<?xf32>) {
|
||||||
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32>
|
// CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
|
||||||
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
||||||
%c0 = constant 0 : index
|
%c0 = constant 0 : index
|
||||||
%0 = dim %arg0, %c0 : memref<?xf32>
|
%0 = dim %arg0, %c0 : memref<?xf32>
|
||||||
|
@ -57,32 +57,32 @@ func @use_of_arg(%arg0: memref<?xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @multiple_blocks(%arg0: !npcomprt.tensor) -> !npcomprt.tensor
|
// CHECK-LABEL: func @multiple_blocks(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
|
||||||
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
|
||||||
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32>
|
// CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
|
||||||
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
|
||||||
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
|
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
|
||||||
br ^bb1(%arg0: memref<?xf32>)
|
br ^bb1(%arg0: memref<?xf32>)
|
||||||
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
|
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
|
||||||
^bb1(%bbarg: memref<?xf32>):
|
^bb1(%bbarg: memref<?xf32>):
|
||||||
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
|
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
|
||||||
// CHECK-NEXT: %[[OUTABIMEMREF:.*]] = npcomprt.from_memref %[[OUTMEMREF]] : memref<*xf32>
|
// CHECK-NEXT: %[[OUTABIMEMREF:.*]] = refbackrt.from_memref %[[OUTMEMREF]] : memref<*xf32>
|
||||||
// CHECK-NEXT: return %[[OUTABIMEMREF]] : !npcomprt.tensor
|
// CHECK-NEXT: return %[[OUTABIMEMREF]] : !refbackrt.tensor
|
||||||
return %bbarg : memref<?xf32>
|
return %bbarg : memref<?xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
||||||
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
|
// CHECK: refbackrt.global @g dense<7.000000e+00> : tensor<10xf32>
|
||||||
refback.global @g dense<7.0> : tensor<10xf32>
|
refback.global @g dense<7.0> : tensor<10xf32>
|
||||||
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
|
// CHECK-LABEL: func @gets_global() -> !refbackrt.tensor
|
||||||
func @gets_global() -> memref<10xf32> {
|
func @gets_global() -> memref<10xf32> {
|
||||||
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
|
// CHECK: %[[GMEMREF:.*]] = refbackrt.get_global @g : memref<*xf32>
|
||||||
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
||||||
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
||||||
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
// CHECK: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
|
||||||
// CHECK: return %[[RET]] : !npcomprt.tensor
|
// CHECK: return %[[RET]] : !refbackrt.tensor
|
||||||
%0 = refback.get_global_memref @g : memref<10xf32>
|
%0 = refback.get_global_memref @g : memref<10xf32>
|
||||||
return %0 : memref<10xf32>
|
return %0 : memref<10xf32>
|
||||||
}
|
}
|
||||||
|
@ -91,7 +91,7 @@ func @gets_global() -> memref<10xf32> {
|
||||||
|
|
||||||
// Test diagnostics.
|
// Test diagnostics.
|
||||||
|
|
||||||
// expected-error @+1 {{func not expressible with npcomprt ABI}}
|
// expected-error @+1 {{func not expressible with refbackrt ABI}}
|
||||||
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -73,32 +73,30 @@ invokeJITModuleWithATenTensors(npcomp::JITModule &jitModule,
|
||||||
std::vector<at::TensorArg> tensorArgs;
|
std::vector<at::TensorArg> tensorArgs;
|
||||||
for (auto arg : llvm::enumerate(args))
|
for (auto arg : llvm::enumerate(args))
|
||||||
tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index()));
|
tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index()));
|
||||||
at::CheckedFrom c = "converting to npcomprt::Tensor";
|
at::CheckedFrom c = "converting to refbackrt::Tensor";
|
||||||
for (auto &tensorArg : tensorArgs)
|
for (auto &tensorArg : tensorArgs)
|
||||||
at::checkScalarType(c, tensorArg, at::ScalarType::Float);
|
at::checkScalarType(c, tensorArg, at::ScalarType::Float);
|
||||||
at::checkAllContiguous(c, tensorArgs);
|
at::checkAllContiguous(c, tensorArgs);
|
||||||
|
|
||||||
SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> npcomprtInputs;
|
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> refbackInputs;
|
||||||
for (at::Tensor arg : args) {
|
for (at::Tensor arg : args) {
|
||||||
SmallVector<int32_t, 6> extents(arg.sizes().begin(), arg.sizes().end());
|
SmallVector<int32_t, 6> extents(arg.sizes().begin(), arg.sizes().end());
|
||||||
float *data = arg.storage().data<float>();
|
float *data = arg.storage().data<float>();
|
||||||
// This does a deep copy of the data. Let's see if it shows up on the
|
// This does a deep copy of the data. Let's see if it shows up on the
|
||||||
// profile.
|
// profile.
|
||||||
npcomprtInputs.push_back(npcomprt::Tensor::create(
|
refbackInputs.push_back(refbackrt::Tensor::create(
|
||||||
npcomprt::ArrayRef<int32_t>(extents.data(), extents.size()),
|
refbackrt::ArrayRef<int32_t>(extents.data(), extents.size()),
|
||||||
npcomprt::ElementType::F32, data));
|
refbackrt::ElementType::F32, data));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Invoke the RefBackend function.
|
// Invoke the RefBackend function.
|
||||||
// TODO: The mishmash of terminology "npcomprt", "refback", "npcomp" in this
|
auto expectedOutputs = jitModule.invoke(invokeFunction, refbackInputs);
|
||||||
// file is getting out of hand.
|
|
||||||
auto expectedOutputs = jitModule.invoke(invokeFunction, npcomprtInputs);
|
|
||||||
if (!expectedOutputs)
|
if (!expectedOutputs)
|
||||||
return expectedOutputs.takeError();
|
return expectedOutputs.takeError();
|
||||||
auto npcomprtOutputs = std::move(*expectedOutputs);
|
auto refbackrtOutputs = std::move(*expectedOutputs);
|
||||||
|
|
||||||
std::vector<at::Tensor> results;
|
std::vector<at::Tensor> results;
|
||||||
for (auto output : npcomprtOutputs) {
|
for (auto output : refbackrtOutputs) {
|
||||||
std::vector<int64_t> sizes(output->getExtents().data(),
|
std::vector<int64_t> sizes(output->getExtents().data(),
|
||||||
output->getExtents().data() +
|
output->getExtents().data() +
|
||||||
output->getExtents().size());
|
output->getExtents().size());
|
||||||
|
|
|
@ -35,7 +35,7 @@ static Error make_string_error(const Twine &message) {
|
||||||
llvm::inconvertibleErrorCode());
|
llvm::inconvertibleErrorCode());
|
||||||
}
|
}
|
||||||
|
|
||||||
static Expected<npcomprt::Ref<npcomprt::Tensor>>
|
static Expected<refbackrt::Ref<refbackrt::Tensor>>
|
||||||
convertAttrToTensor(Attribute attr) {
|
convertAttrToTensor(Attribute attr) {
|
||||||
auto type = attr.getType().dyn_cast<RankedTensorType>();
|
auto type = attr.getType().dyn_cast<RankedTensorType>();
|
||||||
if (!type)
|
if (!type)
|
||||||
|
@ -47,18 +47,18 @@ convertAttrToTensor(Attribute attr) {
|
||||||
if (elementType.isF32()) {
|
if (elementType.isF32()) {
|
||||||
auto values = llvm::to_vector<100>(llvm::map_range(
|
auto values = llvm::to_vector<100>(llvm::map_range(
|
||||||
denseFp, [](APFloat f) { return f.convertToFloat(); }));
|
denseFp, [](APFloat f) { return f.convertToFloat(); }));
|
||||||
return npcomprt::Tensor::create(
|
return refbackrt::Tensor::create(
|
||||||
npcomprt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
|
||||||
npcomprt::ElementType::F32, static_cast<void *>(values.data()));
|
refbackrt::ElementType::F32, static_cast<void *>(values.data()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return make_string_error("unhandled argument");
|
return make_string_error("unhandled argument");
|
||||||
}
|
}
|
||||||
|
|
||||||
static Expected<SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
|
static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
|
||||||
createInputs(ArrayRef<StringRef> argValues) {
|
createInputs(ArrayRef<StringRef> argValues) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> ret;
|
SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret;
|
||||||
for (auto argValue : argValues) {
|
for (auto argValue : argValues) {
|
||||||
auto attr = parseAttribute(argValue, &context);
|
auto attr = parseAttribute(argValue, &context);
|
||||||
if (!attr)
|
if (!attr)
|
||||||
|
@ -71,14 +71,14 @@ createInputs(ArrayRef<StringRef> argValues) {
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
static Type convertToMLIRType(npcomprt::ElementType type, Builder &builder) {
|
static Type convertToMLIRType(refbackrt::ElementType type, Builder &builder) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case npcomprt::ElementType::F32:
|
case refbackrt::ElementType::F32:
|
||||||
return builder.getF32Type();
|
return builder.getF32Type();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static RankedTensorType getCorrespondingMLIRTensorType(npcomprt::Tensor &tensor,
|
static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor,
|
||||||
Builder &builder) {
|
Builder &builder) {
|
||||||
auto elementType = convertToMLIRType(tensor.getElementType(), builder);
|
auto elementType = convertToMLIRType(tensor.getElementType(), builder);
|
||||||
SmallVector<int64_t, 6> extents;
|
SmallVector<int64_t, 6> extents;
|
||||||
|
@ -87,11 +87,11 @@ static RankedTensorType getCorrespondingMLIRTensorType(npcomprt::Tensor &tensor,
|
||||||
return RankedTensorType::get(extents, elementType);
|
return RankedTensorType::get(extents, elementType);
|
||||||
}
|
}
|
||||||
|
|
||||||
static Attribute convertToMLIRAttribute(npcomprt::Tensor &tensor,
|
static Attribute convertToMLIRAttribute(refbackrt::Tensor &tensor,
|
||||||
Builder &builder) {
|
Builder &builder) {
|
||||||
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
|
||||||
switch (tensor.getElementType()) {
|
switch (tensor.getElementType()) {
|
||||||
case npcomprt::ElementType::F32: {
|
case refbackrt::ElementType::F32: {
|
||||||
SmallVector<float, 100> values;
|
SmallVector<float, 100> values;
|
||||||
auto *basePtr = tensor.getData<float>();
|
auto *basePtr = tensor.getData<float>();
|
||||||
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
for (int i = 0, e = type.getNumElements(); i < e; i++)
|
||||||
|
@ -101,14 +101,14 @@ static Attribute convertToMLIRAttribute(npcomprt::Tensor &tensor,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printOutput(npcomprt::Tensor &tensor, llvm::raw_ostream &os) {
|
static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
|
||||||
MLIRContext context;
|
MLIRContext context;
|
||||||
Builder builder(&context);
|
Builder builder(&context);
|
||||||
auto attr = convertToMLIRAttribute(tensor, builder);
|
auto attr = convertToMLIRAttribute(tensor, builder);
|
||||||
attr.print(os);
|
attr.print(os);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void printOutputs(ArrayRef<npcomprt::Ref<npcomprt::Tensor>> outputs,
|
static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs,
|
||||||
llvm::raw_ostream &os) {
|
llvm::raw_ostream &os) {
|
||||||
for (auto output : llvm::enumerate(outputs)) {
|
for (auto output : llvm::enumerate(outputs)) {
|
||||||
os << "output #" << output.index() << ": ";
|
os << "output #" << output.index() << ": ";
|
||||||
|
|
Loading…
Reference in New Issue