[RefBackend] Rename Npcomprt dialect to Refbackrt.

Branch: pull/73/head
Author: Sean Silva
Date: 2020-10-07 17:12:52 -07:00
Parent: 83ad70ef54
Commit: bf99a82832

41 changed files with 352 additions and 352 deletions
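The change is a mechanical rename: the dialect name `npcomprt` becomes `refbackrt`, the C++ namespace `mlir::NPCOMP::npcomprt` becomes `mlir::NPCOMP::refbackrt`, include paths move from `npcomp/Dialect/Npcomprt/...` to `npcomp/Dialect/Refbackrt/...`, and the CMake targets are renamed to match. As a rough sketch of what downstream code does after this commit, using only names that appear in the diffs below (the wrapping function itself is illustrative, not from the repository):

    // Illustrative only: registers the renamed dialect. The registry type and
    // insert<> call mirror InitAll.cpp further down in this diff.
    #include "mlir/IR/Dialect.h"
    #include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"  // was .../Npcomprt/IR/NpcomprtDialect.h

    void registerRefbackrtDialect(mlir::DialectRegistry &registry) {
      // Was: registry.insert<mlir::NPCOMP::npcomprt::NpcomprtDialect>();
      registry.insert<mlir::NPCOMP::refbackrt::RefbackrtDialect>();
    }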


@@ -1,8 +1,8 @@
 add_subdirectory(ATen)
 add_subdirectory(Basicpy)
-add_subdirectory(Npcomprt)
 add_subdirectory(Numpy)
 add_subdirectory(RefBackend)
+add_subdirectory(Refbackrt)
 add_subdirectory(TCF)
 add_subdirectory(TCP)
 add_subdirectory(Torch)


@@ -1 +0,0 @@
-add_mlir_dialect(NpcomprtOps npcomprt)


@@ -0,0 +1 @@
+add_mlir_dialect(RefbackrtOps refbackrt)


@@ -6,31 +6,31 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef NPCOMPRT_BASE
+#ifndef REFBACKRT_BASE
-#define NPCOMPRT_BASE
+#define REFBACKRT_BASE
 include "mlir/IR/OpBase.td"
-def Npcomprt_Dialect : Dialect {
+def Refbackrt_Dialect : Dialect {
-let name = "npcomprt";
+let name = "refbackrt";
-let cppNamespace = "::mlir::NPCOMP::npcomprt";
+let cppNamespace = "::mlir::NPCOMP::refbackrt";
 let description = [{
-The `npcomprt` dialect is the IR manifestation for interaction with the
+The `refbackrt` dialect is the IR manifestation for interaction with the
-npcomp runtime. It primarily serves as a layer that enapsulates the data
+reference backend runtime. It primarily serves as a layer that enapsulates the
-structures and functions available in the runtime, and faciliates
+data structures and functions available in the runtime, and faciliates
 conversion to those conventions, such as by providing utilities for being
 lowered to the llvm dialect.
 }];
 }
-def Npcomprt_Tensor
+def Refbackrt_Tensor
 : DialectType<
-Npcomprt_Dialect,
+Refbackrt_Dialect,
-CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">,
+CPred<"$_self.isa<::mlir::NPCOMP::refbackrt::TensorType>()">,
-"npcomprt.tensor">,
+"refbackrt.tensor">,
 BuildableType<
-"$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> {
+"$_builder.getType<::mlir::NPCOMP::refbackrt::TensorType>()"> {
 let typeDescription = [{The runtime type that represents a buffer.}];
 }
-#endif // #ifndef NPCOMPRT_BASE
+#endif // #ifndef REFBACKRT_BASE


@@ -6,14 +6,14 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTDIALECT_H
+#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
-#define NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTDIALECT_H
+#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H
 #include "mlir/IR/Dialect.h"
 namespace mlir {
 namespace NPCOMP {
-namespace npcomprt {
+namespace refbackrt {
 class TensorType : public Type::TypeBase<TensorType, Type, TypeStorage> {
 public:
@@ -22,10 +22,10 @@ public:
 static TensorType get(MLIRContext *context) { return Base::get(context); }
 };
-} // namespace npcomprt
+} // namespace refbackrt
 } // namespace NPCOMP
 } // namespace mlir
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOpsDialect.h.inc"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOpsDialect.h.inc"
-#endif // NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTDIALECT_H
+#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTDIALECT_H


@@ -6,8 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTOPS_H
+#ifndef NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
-#define NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTOPS_H
+#define NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
@@ -15,6 +15,6 @@
 #include "mlir/IR/SymbolTable.h"
 #define GET_OP_CLASSES
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h.inc"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h.inc"
-#endif // NPCOMP_DIALECT_NPCOMPRT_IR_NPCOMPRTOPS_H
+#endif // NPCOMP_DIALECT_REFBACKRT_IR_REFBACKRTOPS_H


@@ -6,37 +6,37 @@
 //
 //===----------------------------------------------------------------------===//
-#ifndef NPCOMPRT_OPS
+#ifndef REFBACKRT_OPS
-#define NPCOMPRT_OPS
+#define REFBACKRT_OPS
-include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td"
+include "npcomp/Dialect/Refbackrt/IR/RefbackrtBase.td"
 include "mlir/IR/SymbolInterfaces.td"
-class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []>
+class Refbackrt_Op<string mnemonic, list<OpTrait> traits = []>
-: Op<Npcomprt_Dialect, mnemonic, traits> {
+: Op<Refbackrt_Dialect, mnemonic, traits> {
 }
-def Npcomprt_ToMemrefOp : Npcomprt_Op<"to_memref"> {
+def Refbackrt_ToMemrefOp : Refbackrt_Op<"to_memref"> {
 let summary = "Gets a memref descriptor from a tensor";
 let description = [{
 Gets a memref descriptor from a tensor.
 }];
-let arguments = (ins Npcomprt_Tensor:$tensor);
+let arguments = (ins Refbackrt_Tensor:$tensor);
 let results = (outs AnyUnrankedMemRef:$memref);
 let assemblyFormat = "$tensor attr-dict `:` type($memref)";
 }
-def Npcomprt_FromMemrefOp : Npcomprt_Op<"from_memref"> {
+def Refbackrt_FromMemrefOp : Refbackrt_Op<"from_memref"> {
 let summary = "Converts a memref descriptor to a tensor";
 let description = [{
 Copies the data from a memref into a new tensor.
 }];
 let arguments = (ins AnyUnrankedMemRef:$memref);
-let results = (outs Npcomprt_Tensor:$tensor);
+let results = (outs Refbackrt_Tensor:$tensor);
 let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
-def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
+def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
 let summary = "Aborts if the predicate is true";
 let description = [{
 Aborts if the predicate is true.
@@ -46,7 +46,7 @@ def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
 let assemblyFormat = "$pred `,` $msg attr-dict";
 }
-def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
+def Refbackrt_GlobalOp : Refbackrt_Op<"global", [Symbol]> {
 let summary = "Represents a global variable";
 let description = [{
 Represents a global variable.
@@ -61,19 +61,19 @@ def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
 let parser = [{ return ::parse$cppClass(parser, result); }];
 }
-def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
+def Refbackrt_GetGlobalOp : Refbackrt_Op<"get_global"> {
 let summary = "Obtain a rank-erased memref pointing at the given global";
 let description = [{
 Obtain a rank-erased memref pointing at the given global.
 TODO: As we define the runtime layer better, we should have fewer
 entry points that return memrefs, or at least have a clearer separation
-between the "memref world" and the "npcomprt world".
+between the "memref world" and the "refbackrt world".
 Something like forming IREE dispatch regions seems to be the missing thing:
 - Everything inside the dispatch regions gets things marshaled from the
-runtime (flow/hal/npcomprt) layer to/from memrefs in a clear way.
+runtime (flow/hal/refbackrt) layer to/from memrefs in a clear way.
 - Everything outside the dispatch regions purely uses the runtime
-(flow/hal/npcomprt) data structures.
+(flow/hal/refbackrt) data structures.
 Globals should be one of the things that are purely runtime data structures,
 rather than using memrefs. For now, using memrefs is simpler though.
 }];
@@ -83,12 +83,12 @@ def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
 let verifier = "return ::verify$cppClass(*this);";
 }
-def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
+def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [
 SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
 ]> {
 let summary = "Global metadata for the module";
 let description = [{
-This op contains a region containing npcomprt.func_metadata ops,
+This op contains a region containing refbackrt.func_metadata ops,
 which give information about the functions in the module. This allows
 the module to be introspected when it is loaded, such as looking up
 functions.
@@ -110,8 +110,8 @@ def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
 let parser = [{ return ::parse$cppClass(parser, result); }];
 }
-def Npcomprt_ModuleMetadataTerminatorOp
+def Refbackrt_ModuleMetadataTerminatorOp
-: Npcomprt_Op<"module_metadata_terminator",
+: Refbackrt_Op<"module_metadata_terminator",
 [Terminator, HasParent<"ModuleMetadataOp">]> {
 let summary = "Implicit terminator for ModuleMetadataOp's region";
 let arguments = (ins);
@@ -119,8 +119,8 @@ def Npcomprt_ModuleMetadataTerminatorOp
 let assemblyFormat = "attr-dict";
 }
-def Npcomprt_FuncMetadataOp
+def Refbackrt_FuncMetadataOp
-: Npcomprt_Op<"func_metadata", [HasParent<"ModuleMetadataOp">]> {
+: Refbackrt_Op<"func_metadata", [HasParent<"ModuleMetadataOp">]> {
 let summary = "Runtime metadata for a single func";
 let description = [{
 Runtime metadata for a single func.
@@ -138,4 +138,4 @@ def Npcomprt_FuncMetadataOp
 let verifier = [{ return ::verify(*this); }];
 }
-#endif // #ifndef NPCOMPRT_OPS
+#endif // #ifndef REFBACKRT_OPS
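The `refbackrt.to_memref` and `refbackrt.from_memref` ops defined above are what the ABI-lowering pass later in this diff materializes at function boundaries. A minimal sketch of building them from C++, patterned on the `rewriter.create<refbackrt::ToMemrefOp>` and `builder.create<refbackrt::FromMemrefOp>` calls in LowerToRefbackrtABI.cpp below; the builder, location, and values here are placeholders:

    // Sketch: round-trips a !refbackrt.tensor value through an unranked memref.
    // `loc`, `tensorVal`, and `memrefTy` are placeholders supplied by the caller.
    #include "mlir/IR/Builders.h"
    #include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
    #include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"

    using namespace mlir;

    void buildAbiCasts(OpBuilder &builder, Location loc, Value tensorVal,
                       UnrankedMemRefType memrefTy) {
      // !refbackrt.tensor -> unranked memref (refbackrt.to_memref).
      Value memrefVal = builder.create<NPCOMP::refbackrt::ToMemrefOp>(
          loc, memrefTy, tensorVal);
      // memref -> !refbackrt.tensor (refbackrt.from_memref).
      builder.create<NPCOMP::refbackrt::FromMemrefOp>(
          loc, NPCOMP::refbackrt::TensorType::get(builder.getContext()),
          memrefVal);
    }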


@@ -24,7 +24,7 @@ class PassManager;
 } // namespace mlir
 namespace npcomp {
-// Wrapper around npcomprt data structures and a JITted module, facilitating
+// Wrapper around refbackrt data structures and a JITted module, facilitating
 // interaction.
 class JITModule {
 public:
@@ -40,14 +40,14 @@ public:
 fromCompiledModule(mlir::ModuleOp module,
 llvm::ArrayRef<llvm::StringRef> sharedLibs);
-llvm::Expected<llvm::SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
+llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
 invoke(llvm::StringRef functionName,
-llvm::ArrayRef<npcomprt::Ref<npcomprt::Tensor>> inputs);
+llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs);
 private:
 JITModule();
 std::unique_ptr<mlir::ExecutionEngine> engine;
-npcomprt::ModuleDescriptor *descriptor;
+refbackrt::ModuleDescriptor *descriptor;
 };
 } // namespace npcomp
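For orientation, here is a hedged sketch of driving the renamed JIT API. Only the `invoke` signature and the `refbackrt::Ref<refbackrt::Tensor>` handle type are visible in the hunk above; the return type of `fromCompiledModule` (assumed here to be `llvm::Expected` of an owning `JITModule` pointer), its use as a static factory, the header path, and the entry-point name "forward" are assumptions:

    // Sketch under stated assumptions; not verbatim from the repository.
    #include "llvm/ADT/ArrayRef.h"
    #include "llvm/Support/Error.h"
    #include "npcomp/JITRuntime/JITModule.h"  // assumed header location

    llvm::Error runOnce(mlir::ModuleOp compiledModule,
                        llvm::ArrayRef<llvm::StringRef> sharedLibs,
                        llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) {
      // Assumed: static factory returning llvm::Expected<std::unique_ptr<JITModule>>.
      auto expectedJit =
          npcomp::JITModule::fromCompiledModule(compiledModule, sharedLibs);
      if (!expectedJit)
        return expectedJit.takeError();
      // "forward" is a placeholder entry-point name.
      auto expectedOutputs = (*expectedJit)->invoke("forward", inputs);
      if (!expectedOutputs)
        return expectedOutputs.takeError();
      return llvm::Error::success();
    }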


@@ -52,10 +52,10 @@ def LowerStructuralToMemref :
 let constructor = "mlir::NPCOMP::createLowerStructuralToMemrefPass()";
 }
-def LowerToNpcomprtABI : Pass<"lower-to-npcomprt-abi", "ModuleOp"> {
+def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
-let summary = "Lower constructs requiring runtime support to `npcomprt`";
+let summary = "Lower constructs requiring runtime support to `refbackrt`";
 let description = [{
-We have a specialized dialect `npcomprt` which models our runtime's data
+We have a specialized dialect `refbackrt` which models our runtime's data
 structures, and function signatures (and presumably eventually, other
 ABI boundaries like external calls if we ever support it) will be
 converted.
@@ -65,7 +65,7 @@ def LowerToNpcomprtABI : Pass<"lower-to-npcomprt-abi", "ModuleOp"> {
 - globals
 - error handling
 }];
-let constructor = "mlir::NPCOMP::createLowerToNpcomprtABIPass()";
+let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()";
 }
 def LowerAllocMemRefOps : Pass<"lower-alloc-memref-ops", "FuncOp"> {
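The renamed pass keeps the same position in the pipeline; only its command-line argument (`lower-to-refbackrt-abi`) and its constructor change. A small sketch of scheduling it on a pass manager, mirroring the `pm.addPass(createLowerToRefbackrtABIPass())` call in RefBackend.cpp near the end of this diff (the header path here is an assumption):

    // Sketch: schedule the renamed module-level pass.
    #include "mlir/Pass/PassManager.h"
    #include "npcomp/RefBackend/Passes.h"  // assumed location of the declaration

    void addRefbackrtAbiLowering(mlir::PassManager &pm) {
      // Was: pm.addPass(mlir::NPCOMP::createLowerToNpcomprtABIPass());
      pm.addPass(mlir::NPCOMP::createLowerToRefbackrtABIPass());
    }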


@@ -34,7 +34,7 @@ createLowerConstantTensorsToMemrefPass();
 std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
-std::unique_ptr<OperationPass<ModuleOp>> createLowerToNpcomprtABIPass();
+std::unique_ptr<OperationPass<ModuleOp>> createLowerToRefbackrtABIPass();
 std::unique_ptr<OperationPass<FuncOp>> createLowerAllocMemRefOpsPass();


@@ -1,4 +1,4 @@
-RefBackendRt (namespace `refbackrt`) is the runtime support library for the
+Refbackrt (namespace `refbackrt`) is the runtime support library for the
 RefBackend backend. It is best practice to keep compiler and runtime code
 totally firewalled.


@@ -19,7 +19,7 @@
 #include <cstdint>
 #include <cstring>
-namespace npcomprt {
+namespace refbackrt {
 class StringRef {
 public:
 StringRef(const char *ptr, std::size_t length) : ptr(ptr), length(length){};
@@ -94,6 +94,6 @@ inline bool failed(LogicalResult result) {
 return result.value == LogicalResult::Failure;
 }
-} // namespace npcomprt
+} // namespace refbackrt
 #endif // NPCOMP_RUNTIME_SUPPORT_H


@@ -25,7 +25,7 @@
 #include <atomic>
 #include <cstdlib>
-namespace npcomprt {
+namespace refbackrt {
 // Reference-counted handle to a type with a `refCount` member.
 template <typename T> class Ref {
@@ -178,6 +178,6 @@ LogicalResult getMetadata(ModuleDescriptor *moduleDescriptor,
 StringRef functionName,
 FunctionMetadata &outMetadata);
-} // namespace npcomprt
+} // namespace refbackrt
 #endif // NPCOMP_RUNTIME_USERAPI_H
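UserAPI.h declares `refbackrt::getMetadata`, and Support.h provides the `refbackrt::failed` / `LogicalResult` helpers it pairs with. A hedged sketch of the lookup pattern, which is the same one `JITModule::invoke` uses later in this diff; the descriptor pointer and the function name are placeholders, and it is assumed that including UserAPI.h pulls in the support types:

    // Sketch: query per-function metadata from a loaded refbackrt module.
    #include "npcomp/RefBackend/Runtime/UserAPI.h"

    bool hasFunction(refbackrt::ModuleDescriptor *descriptor) {
      refbackrt::FunctionMetadata metadata;
      // "forward" is a placeholder; StringRef(ptr, length) is the ctor shown above.
      if (refbackrt::failed(refbackrt::getMetadata(
              descriptor, refbackrt::StringRef("forward", 7), metadata)))
        return false;  // unknown function
      // metadata.numInputs / metadata.numOutputs describe the calling convention.
      return true;
    }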


@@ -22,8 +22,8 @@ using llvm::Twine;
 using mlir::PyModuleOp;
 using mlir::PyPassManager;
 using npcomp::JITModule;
-using npcomprt::Ref;
+using refbackrt::Ref;
-using npcomprt::Tensor;
+using refbackrt::Tensor;
 template <typename T>
 static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
@@ -37,10 +37,10 @@ static T checkError(llvm::Expected<T> &&expected, Twine banner = {}) {
 throw py::raisePyError(PyExc_RuntimeError, errorMessage.c_str());
 }
-static npcomprt::ElementType
+static refbackrt::ElementType
 mapBufferFormatToElementType(const std::string &format, py::ssize_t itemSize) {
 if (format == "f")
-return npcomprt::ElementType::F32;
+return refbackrt::ElementType::F32;
 std::string message("unsupported buffer format: ");
 message.append(format);
@@ -61,7 +61,7 @@ static Ref<Tensor> copyBufferToTensor(py::buffer buffer) {
 // TODO: Switch Tensor extents to ssize_t for efficiency.
 SmallVector<std::int32_t, 4> extents(info.shape.begin(), info.shape.end());
 return Tensor::create(
-npcomprt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
+refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
 elementType, info.ptr);
 }
@@ -73,7 +73,7 @@ py::array wrapTensorAsArray(Ref<Tensor> tensor) {
 const char *format;
 switch (tensor->getElementType()) {
-case npcomprt::ElementType::F32:
+case refbackrt::ElementType::F32:
 format = "f";
 break;
 default:


@@ -43,7 +43,7 @@ add_mlir_library(NPCOMPInitAll
 NPCOMPTCPDialect
 NPCOMPTCFDialect
 NPCOMPTorchDialect
-NPCOMPNpcomprtDialect
+NPCOMPRefbackrtDialect
 NPCOMPATenDialect
 NPCOMPBasicpyDialect
 NPCOMPBasicpyPasses


@@ -1,8 +1,8 @@
 add_subdirectory(ATen)
 add_subdirectory(Basicpy)
-add_subdirectory(Npcomprt)
 add_subdirectory(Numpy)
 add_subdirectory(RefBackend)
+add_subdirectory(Refbackrt)
 add_subdirectory(TCF)
 add_subdirectory(TCP)
 add_subdirectory(Torch)


@@ -1,17 +0,0 @@
-add_mlir_dialect_library(NPCOMPNpcomprtDialect
-NpcomprtDialect.cpp
-NpcomprtOps.cpp
-ADDITIONAL_HEADER_DIRS
-${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Npcomprt
-DEPENDS
-MLIRNpcomprtOpsIncGen
-LINK_COMPONENTS
-Core
-LINK_LIBS PUBLIC
-MLIRIR
-MLIRSupport
-)


@@ -0,0 +1,17 @@
+add_mlir_dialect_library(NPCOMPRefbackrtDialect
+RefbackrtDialect.cpp
+RefbackrtOps.cpp
+ADDITIONAL_HEADER_DIRS
+${PROJECT_SOURCE_DIR}/include/npcomp/Dialect/Refbackrt
+DEPENDS
+MLIRRefbackrtOpsIncGen
+LINK_COMPONENTS
+Core
+LINK_LIBS PUBLIC
+MLIRIR
+MLIRSupport
+)


@@ -6,23 +6,23 @@
 //
 //===----------------------------------------------------------------------===//
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
 #include "mlir/IR/DialectImplementation.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
 #include "llvm/ADT/TypeSwitch.h"
 using namespace mlir;
-using namespace mlir::NPCOMP::npcomprt;
+using namespace mlir::NPCOMP::refbackrt;
-void NpcomprtDialect::initialize() {
+void RefbackrtDialect::initialize() {
 addOperations<
 #define GET_OP_LIST
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.cpp.inc"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"
 >();
 addTypes<TensorType>();
 }
-Type NpcomprtDialect::parseType(DialectAsmParser &parser) const {
+Type RefbackrtDialect::parseType(DialectAsmParser &parser) const {
 StringRef keyword;
 if (parser.parseKeyword(&keyword))
 return Type();
@@ -30,14 +30,14 @@ Type NpcomprtDialect::parseType(DialectAsmParser &parser) const {
 if (keyword == "tensor")
 return TensorType::get(getContext());
-parser.emitError(parser.getNameLoc(), "unknown type in 'npcomprt' dialect: ")
+parser.emitError(parser.getNameLoc(), "unknown type in 'refbackrt' dialect: ")
 << keyword;
 return Type();
 }
-void NpcomprtDialect::printType(Type type, DialectAsmPrinter &os) const {
+void RefbackrtDialect::printType(Type type, DialectAsmPrinter &os) const {
 TypeSwitch<Type>(type)
-.Case<NPCOMP::npcomprt::TensorType>([&](Type) { os << "tensor"; })
+.Case<NPCOMP::refbackrt::TensorType>([&](Type) { os << "tensor"; })
 .Default(
-[&](Type) { llvm_unreachable("unexpected 'npcomprt' type kind"); });
+[&](Type) { llvm_unreachable("unexpected 'refbackrt' type kind"); });
 }


@@ -6,22 +6,22 @@
 //
 //===----------------------------------------------------------------------===//
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/SymbolTable.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
 using namespace mlir;
-using namespace mlir::NPCOMP::npcomprt;
+using namespace mlir::NPCOMP::refbackrt;
 //===----------------------------------------------------------------------===//
 // GlobalOp
 //===----------------------------------------------------------------------===//
 static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
-p << "npcomprt.global ";
+p << "refbackrt.global ";
 p.printSymbolName(op.sym_name());
 p << ' ';
 p.printOptionalAttrDictWithKeyword(op.getAttrs(),
@@ -49,7 +49,7 @@ static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
 static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
 auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
 if (!global)
-return op.emitError() << "must reference a valid npcomprt.global";
+return op.emitError() << "must reference a valid refbackrt.global";
 auto globalType = global.value().getType();
 auto resultType = op.getType().cast<ShapedType>();
 if (globalType.getElementType() != resultType.getElementType())
@@ -62,7 +62,7 @@ static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
 //===----------------------------------------------------------------------===//
 static void printModuleMetadataOp(OpAsmPrinter &p, ModuleMetadataOp &op) {
-p << "npcomprt.module_metadata";
+p << "refbackrt.module_metadata";
 p.printOptionalAttrDictWithKeyword(op.getAttrs());
 p.printRegion(op.metadatas(), /*printEntryBlockArgs=*/false,
 /*printBlockTerminators=*/false);
@@ -99,4 +99,4 @@ static LogicalResult verify(FuncMetadataOp op) {
 }
 #define GET_OP_CLASSES
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.cpp.inc"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.cpp.inc"


@@ -12,7 +12,7 @@
 #include "npcomp/Dialect/ATen/ATenPasses.h"
 #include "npcomp/Dialect/Basicpy/IR/BasicpyDialect.h"
 #include "npcomp/Dialect/Basicpy/Transforms/Passes.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
 #include "npcomp/Dialect/Numpy/IR/NumpyDialect.h"
 #include "npcomp/Dialect/Numpy/Transforms/Passes.h"
 #include "npcomp/Dialect/RefBackend/IR/RefBackendDialect.h"
@@ -73,7 +73,7 @@ void mlir::NPCOMP::registerAllDialects(mlir::DialectRegistry &registry) {
 registry.insert<mlir::NPCOMP::aten::ATenDialect,
 Basicpy::BasicpyDialect,
 Numpy::NumpyDialect,
-npcomprt::NpcomprtDialect,
+refbackrt::RefbackrtDialect,
 refback::RefBackendDialect,
 tcf::TCFDialect,
 tcp::TCPDialect,


@@ -5,7 +5,7 @@ add_mlir_library(NPCOMPRefBackend
 BypassShapes.cpp
 RefBackend.cpp
 LowerToLLVM.cpp
-LowerToNpcomprtABI.cpp
+LowerToRefbackrtABI.cpp
 TensorToMemref/LowerConstantTensorsToMemref.cpp
 TensorToMemref/LowerShapedResultsToMemref.cpp
 TensorToMemref/LowerStdToMemref.cpp


@@ -51,39 +51,40 @@ JITModule::fromCompiledModule(mlir::ModuleOp module,
 if (!expectedAddress)
 return expectedAddress.takeError();
 ret->descriptor =
-reinterpret_cast<npcomprt::ModuleDescriptor *>(*expectedAddress);
+reinterpret_cast<refbackrt::ModuleDescriptor *>(*expectedAddress);
 return std::move(ret);
 }
-// Converter for bridging to npcomprt llvm-lookalike data structures.
+// Converter for bridging to refbackrt llvm-lookalike data structures.
-static npcomprt::StringRef toNpcomprt(llvm::StringRef s) {
+static refbackrt::StringRef toRefbackrt(llvm::StringRef s) {
-return npcomprt::StringRef(s.data(), s.size());
+return refbackrt::StringRef(s.data(), s.size());
 }
 template <typename T>
-static npcomprt::ArrayRef<T> toNpcomprt(llvm::ArrayRef<T> a) {
+static refbackrt::ArrayRef<T> toRefbackrt(llvm::ArrayRef<T> a) {
-return npcomprt::ArrayRef<T>(a.data(), a.size());
+return refbackrt::ArrayRef<T>(a.data(), a.size());
 }
 template <typename T>
-static npcomprt::MutableArrayRef<T> toNpcomprt(llvm::MutableArrayRef<T> a) {
+static refbackrt::MutableArrayRef<T> toRefbackrt(llvm::MutableArrayRef<T> a) {
-return npcomprt::MutableArrayRef<T>(a.data(), a.size());
+return refbackrt::MutableArrayRef<T>(a.data(), a.size());
 }
-llvm::Expected<llvm::SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>>
+llvm::Expected<llvm::SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
 JITModule::invoke(llvm::StringRef functionName,
-llvm::ArrayRef<npcomprt::Ref<npcomprt::Tensor>> inputs) {
+llvm::ArrayRef<refbackrt::Ref<refbackrt::Tensor>> inputs) {
-npcomprt::FunctionMetadata metadata;
+refbackrt::FunctionMetadata metadata;
-if (npcomprt::failed(npcomprt::getMetadata(
+if (refbackrt::failed(refbackrt::getMetadata(
-descriptor, toNpcomprt(functionName), metadata)))
+descriptor, toRefbackrt(functionName), metadata)))
 return make_string_error("unknown function: " + Twine(functionName));
-SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> outputs(metadata.numOutputs);
+SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> outputs(
+metadata.numOutputs);
 if (metadata.numInputs != static_cast<std::int32_t>(inputs.size()))
 return make_string_error("invoking '" + Twine(functionName) +
 "': expected " + Twine(metadata.numInputs) +
 " inputs");
-npcomprt::invoke(
+refbackrt::invoke(
-descriptor, toNpcomprt(functionName), toNpcomprt(inputs),
+descriptor, toRefbackrt(functionName), toRefbackrt(inputs),
-toNpcomprt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
+toRefbackrt(llvm::makeMutableArrayRef(outputs.data(), outputs.size())));
 return outputs;
 }


@@ -15,8 +15,8 @@
 #include "mlir/Dialect/StandardOps/Transforms/Passes.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
 using namespace mlir;
 using namespace mlir::NPCOMP;
@@ -29,7 +29,7 @@ using mlir::LLVM::LLVMType;
 // These correspond to the types in CompilerDataStructures.h
 //===----------------------------------------------------------------------===//
-// Get the LLVMType for npcomprt::FuncDescriptor.
+// Get the LLVMType for refbackrt::FuncDescriptor.
 static LLVMType getFuncDescriptorTy(MLIRContext *context) {
 return LLVMType::getStructTy(context, {
 // Name length.
@@ -45,7 +45,7 @@ static LLVMType getFuncDescriptorTy(MLIRContext *context) {
 });
 }
-// Get the LLVMType for npcomprt::ModuleDescriptor.
+// Get the LLVMType for refbackrt::ModuleDescriptor.
 static LLVMType getModuleDescriptorTy(MLIRContext *context) {
 return LLVMType::getStructTy(context,
 {
@@ -56,7 +56,7 @@ static LLVMType getModuleDescriptorTy(MLIRContext *context) {
 });
 }
-// Get the LLVMType for npcomprt::GlobalDescriptor.
+// Get the LLVMType for refbackrt::GlobalDescriptor.
 static LLVMType getGlobalDescriptorTy(MLIRContext *context) {
 return LLVMType::getStructTy(
 // std::int32_t numExtents;
@@ -97,13 +97,13 @@ namespace {
 // FromMemrefOp requires special handling so that the unranked memref descriptor
 // gets passed as two separate arguments instead of as a struct.
 class FromMemrefOpCompilerRuntimeLowering
-: public OpConversionPattern<npcomprt::FromMemrefOp> {
+: public OpConversionPattern<refbackrt::FromMemrefOp> {
 public:
 FromMemrefOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
-: OpConversionPattern<npcomprt::FromMemrefOp>(backingFunc.getContext()),
+: OpConversionPattern<refbackrt::FromMemrefOp>(backingFunc.getContext()),
 backingFunc(backingFunc) {}
 LogicalResult
-matchAndRewrite(npcomprt::FromMemrefOp op, ArrayRef<Value> operands,
+matchAndRewrite(refbackrt::FromMemrefOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
 auto structVal = operands[0];
 Value rank = rewriter.create<LLVM::ExtractValueOp>(
@@ -124,17 +124,17 @@ public:
 namespace {
 class GetGlobalOpCompilerRuntimeLowering
-: public OpConversionPattern<npcomprt::GetGlobalOp> {
+: public OpConversionPattern<refbackrt::GetGlobalOp> {
 public:
 GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
-: OpConversionPattern<npcomprt::GetGlobalOp>(backingFunc.getContext()),
+: OpConversionPattern<refbackrt::GetGlobalOp>(backingFunc.getContext()),
 backingFunc(backingFunc) {}
 LogicalResult
-matchAndRewrite(npcomprt::GetGlobalOp op, ArrayRef<Value> operands,
+matchAndRewrite(refbackrt::GetGlobalOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
 // It would be nice if we could use the constructor here that takes just the
 // global, but keeping track of the converted llvm.mlir.global op that gets
-// created from the npcomprt.global while conversion is going on is a
+// created from the refbackrt.global while conversion is going on is a
 // headache.
 //
 // Instead, we rely on the symbol name being the same and the result type
@@ -178,15 +178,15 @@ static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
 namespace {
 class AbortIfOpCompilerRuntimeLowering
-: public OpConversionPattern<npcomprt::AbortIfOp> {
+: public OpConversionPattern<refbackrt::AbortIfOp> {
 public:
 AbortIfOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
-: OpConversionPattern<npcomprt::AbortIfOp>(backingFunc.getContext()),
+: OpConversionPattern<refbackrt::AbortIfOp>(backingFunc.getContext()),
 backingFunc(backingFunc) {}
 LogicalResult
-matchAndRewrite(npcomprt::AbortIfOp op, ArrayRef<Value> operands,
+matchAndRewrite(refbackrt::AbortIfOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
-npcomprt::AbortIfOp::Adaptor adaptor(operands);
+refbackrt::AbortIfOp::Adaptor adaptor(operands);
 auto *context = op.getContext();
 // Create the global string, take its address, and gep to get an `i8*`.
@@ -207,7 +207,7 @@ public:
 };
 } // namespace
-// Create the LLVM runtime function backing the npcomprt op with name `name`
+// Create the LLVM runtime function backing the refbackrt op with name `name`
 // and requiring `type`.
 static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
 OpBuilder &builder,
@@ -242,12 +242,12 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
 {
 auto mlirFunctionType = builder.getFunctionType(
-{builder.getType<npcomprt::TensorType>()},
+{builder.getType<refbackrt::TensorType>()},
 {UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
 LLVMType funcTy = convertFunctionType(mlirFunctionType);
 LLVMFuncOp toMemrefFunc = createCompilerRuntimeFuncDecl(
 "to_memref", funcTy, builder, module.getLoc());
-patterns.insert<TrivialCompilerRuntimeLowering<npcomprt::ToMemrefOp>>(
+patterns.insert<TrivialCompilerRuntimeLowering<refbackrt::ToMemrefOp>>(
 toMemrefFunc);
 }
@@ -256,7 +256,7 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
 // doesn't know its own dtype.
 auto mlirFunctionType = builder.getFunctionType(
 {UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)},
-{builder.getType<npcomprt::TensorType>()});
+{builder.getType<refbackrt::TensorType>()});
 LLVMType funcTy = convertFunctionType(mlirFunctionType);
 LLVMFuncOp fromMemrefFunc = createCompilerRuntimeFuncDecl(
 "from_memref", funcTy, builder, module.getLoc());
@@ -277,24 +277,24 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
 }
 //===----------------------------------------------------------------------===//
-// Lowering for npcomprt.global
+// Lowering for refbackrt.global
 //===----------------------------------------------------------------------===//
 namespace {
-class LowerNpcomprtGlobalOp : public OpConversionPattern<npcomprt::GlobalOp> {
+class LowerRefbackrtGlobalOp : public OpConversionPattern<refbackrt::GlobalOp> {
 public:
-explicit LowerNpcomprtGlobalOp(LLVMTypeConverter &typeConverter)
+explicit LowerRefbackrtGlobalOp(LLVMTypeConverter &typeConverter)
-: OpConversionPattern<npcomprt::GlobalOp>(&typeConverter.getContext()),
+: OpConversionPattern<refbackrt::GlobalOp>(&typeConverter.getContext()),
 typeConverter(typeConverter) {}
 LogicalResult
-matchAndRewrite(npcomprt::GlobalOp op, ArrayRef<Value> operands,
+matchAndRewrite(refbackrt::GlobalOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
 auto *context = rewriter.getContext();
 auto globalDescriptorTy = getGlobalDescriptorTy(context);
 // Create the data buffer.
 auto dataBuffer = createGlobalForDenseElementsAttr(
-(Twine("__npcomprt_global_data_buffer_") + op.sym_name()).str(),
+(Twine("__refbackrt_global_data_buffer_") + op.sym_name()).str(),
 op.value().cast<DenseElementsAttr>(), op, rewriter);
 // Create the extents buffer.
@@ -302,8 +302,8 @@ public:
 llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
 [](int64_t i) -> int32_t { return i; })));
 auto extentsBuffer = createGlobalForDenseElementsAttr(
-(Twine("__npcomprt_global_extents_") + op.sym_name()).str(), extentsI32,
+(Twine("__refbackrt_global_extents_") + op.sym_name()).str(),
-op, rewriter);
+extentsI32, op, rewriter);
 // Create the GlobalDescriptor.
 auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
@@ -352,7 +352,7 @@ public:
 private:
 // TODO: It feels like MLIR core should have better utilities for this.
 LLVM::GlobalOp createGlobalForDenseElementsAttr(
-StringRef symbolName, DenseElementsAttr elements, npcomprt::GlobalOp op,
+StringRef symbolName, DenseElementsAttr elements, refbackrt::GlobalOp op,
 ConversionPatternRewriter &rewriter) const {
 auto type = elements.getType().cast<ShapedType>();
@@ -384,7 +384,7 @@ private:
 /*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
 }
-LLVMType getLLVMTypeForShapedType(ShapedType type, npcomprt::GlobalOp op,
+LLVMType getLLVMTypeForShapedType(ShapedType type, refbackrt::GlobalOp op,
 ConversionPatternRewriter &rewriter) const {
 auto llvmType =
 typeConverter.convertType(type.getElementType()).cast<LLVMType>();
@@ -425,7 +425,7 @@ private:
 //===----------------------------------------------------------------------===//
 static LLVM::GlobalOp
-createFuncDescriptorArray(ArrayRef<npcomprt::FuncMetadataOp> funcMetadatas,
+createFuncDescriptorArray(ArrayRef<refbackrt::FuncMetadataOp> funcMetadatas,
 OpBuilder &builder, Location loc) {
 auto llvmI32Ty = LLVMType::getIntNTy(builder.getContext(), 32);
@@ -556,14 +556,14 @@ LLVM::GlobalOp createModuleDescriptor(LLVM::GlobalOp funcDescriptorArray,
 namespace {
 class LowerModuleMetadata
-: public OpConversionPattern<npcomprt::ModuleMetadataOp> {
+: public OpConversionPattern<refbackrt::ModuleMetadataOp> {
 public:
 using OpConversionPattern::OpConversionPattern;
 LogicalResult
-matchAndRewrite(npcomprt::ModuleMetadataOp op, ArrayRef<Value> operands,
+matchAndRewrite(refbackrt::ModuleMetadataOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
 auto funcMetadatas =
-llvm::to_vector<6>(op.metadatas().getOps<npcomprt::FuncMetadataOp>());
+llvm::to_vector<6>(op.metadatas().getOps<refbackrt::FuncMetadataOp>());
 auto funcDescriptorArray =
 createFuncDescriptorArray(funcMetadatas, rewriter, op.getLoc());
 auto moduleDescriptor =
@@ -646,7 +646,7 @@ static void storeWrapperResults(LLVM::CallOp callToWrapped, Value resultsPtrPtr,
 // Construct a wrapper function.
 // For an externally visible function f(T1, T2) -> T3, T4, we create a
 // wrapper
-// __npcomprt_wrapper_f(void **inputs, void ** outputs) {
+// __refbackrt_wrapper_f(void **inputs, void ** outputs) {
 // T3 t3;
 // T4 t4;
 // (t3, t4) = f(*cast<T1*>(inputs[0]), *cast<T2*>(inputs[1]));
@@ -664,8 +664,8 @@ static LLVMFuncOp createWrapperFunc(LLVMFuncOp func) {
 auto wrapperTy = LLVMType::getFunctionTy(LLVMType::getVoidTy(context),
 {voidStarStarTy, voidStarStarTy},
 /*isVarArg=*/false);
-constexpr char kNpcomprtWrapperPrefix[] = "__npcomprt_wrapper_";
+constexpr char kRefbackrtWrapperPrefix[] = "__refbackrt_wrapper_";
-auto wrapperName = (Twine(kNpcomprtWrapperPrefix) + func.getName()).str();
+auto wrapperName = (Twine(kRefbackrtWrapperPrefix) + func.getName()).str();
 OpBuilder moduleBuilder(func.getParentRegion());
 LLVMFuncOp wrapper = moduleBuilder.create<LLVMFuncOp>(
 func.getLoc(), wrapperName, wrapperTy, LLVM::Linkage::External);
@@ -693,8 +693,8 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
 LLVMTypeConverter converter(context);
-// npcomprt::TensorType is passed as a `void*` in the ABI.
+// refbackrt::TensorType is passed as a `void*` in the ABI.
-converter.addConversion([&](npcomprt::TensorType type) {
+converter.addConversion([&](refbackrt::TensorType type) {
 return LLVMType::getInt8PtrTy(context);
 });
@@ -706,7 +706,7 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
 target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
 populateStdToLLVMConversionPatterns(converter, patterns);
 patterns.insert<LowerModuleMetadata>(context);
-patterns.insert<LowerNpcomprtGlobalOp>(converter);
+patterns.insert<LowerRefbackrtGlobalOp>(converter);
 // TODO: Move these "std to std" legalizations to their own pass if we grow
 // lots of these patterns.


@@ -14,8 +14,8 @@
 #include "mlir/IR/Verifier.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtDialect.h"
-#include "npcomp/Dialect/Npcomprt/IR/NpcomprtOps.h"
+#include "npcomp/Dialect/Refbackrt/IR/RefbackrtOps.h"
 #include "npcomp/Dialect/RefBackend/IR/RefBackendOps.h"
 using namespace mlir;
@@ -32,10 +32,10 @@ static Type getABIMemrefType(Type type) {
 // Creating module metadata.
 //===----------------------------------------------------------------------===//
-// Returns true if the function signature can be expressed with the npcomprt
+// Returns true if the function signature can be expressed with the refbackrt
 // ABI.
-static bool expressibleWithNpcomprtABI(FunctionType type) {
+static bool expressibleWithRefbackrtABI(FunctionType type) {
-// Currently, only memref types can be exposed at npcomprt ABI boundaries.
+// Currently, only memref types can be exposed at refbackrt ABI boundaries.
 return llvm::all_of(
 llvm::concat<const Type>(type.getInputs(), type.getResults()),
 [](Type t) { return t.isa<MemRefType>(); });
@@ -44,11 +44,11 @@ static bool expressibleWithNpcomprtABI(FunctionType type) {
 static LogicalResult createModuleMetadata(ModuleOp module) {
 auto moduleMetadata =
 OpBuilder::atBlockBegin(module.getBody())
-.create<npcomprt::ModuleMetadataOp>(module.getLoc());
+.create<refbackrt::ModuleMetadataOp>(module.getLoc());
 moduleMetadata.metadatas().push_back(new Block);
 Block &metadatas = moduleMetadata.metadatas().front();
 OpBuilder::atBlockEnd(&metadatas)
-.create<npcomprt::ModuleMetadataTerminatorOp>(module.getLoc());
+.create<refbackrt::ModuleMetadataTerminatorOp>(module.getLoc());
 SymbolTable symbolTable(module);
 auto builder = OpBuilder::atBlockBegin(&metadatas);
@@ -59,13 +59,13 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
 }
 // TODO: Add richer information here such as expected shapes and element
 // types.
-builder.create<npcomprt::FuncMetadataOp>(
+builder.create<refbackrt::FuncMetadataOp>(
 func.getLoc(), builder.getSymbolRefAttr(func.getName()),
 builder.getI32IntegerAttr(func.getNumArguments()),
 builder.getI32IntegerAttr(func.getNumResults()));
-if (!expressibleWithNpcomprtABI(func.getType()))
+if (!expressibleWithRefbackrtABI(func.getType()))
-return func.emitError() << "func not expressible with npcomprt ABI";
+return func.emitError() << "func not expressible with refbackrt ABI";
 }
 return success();
 }
@@ -81,8 +81,8 @@ public:
 LogicalResult
 matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
-rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
+rewriter.replaceOpWithNewOp<refbackrt::GlobalOp>(op, op.sym_name(),
 op.value());
 return success();
 }
 };
@@ -96,7 +96,7 @@ public:
 LogicalResult
 matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
-auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
+auto abiMemref = rewriter.create<refbackrt::GetGlobalOp>(
 op.getLoc(), getABIMemrefType(op.getType()), op.global());
 // Cast back to the original type.
 rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
@@ -113,22 +113,22 @@ public:
 matchAndRewrite(AssertOp op, ArrayRef<Value> operands,
 ConversionPatternRewriter &rewriter) const override {
 AssertOp::Adaptor adaptor(operands);
-// The npcomprt runtime function aborts if the argument is true, rather than
+// The refbackrt runtime function aborts if the argument is true, rather
-// when it is false as an `assert` does. So negate the predicate (by xor'ing
+// than when it is false as an `assert` does. So negate the predicate (by
-// with 1).
+// xor'ing with 1).
 auto c1 = rewriter.create<ConstantOp>(
 op.getLoc(), rewriter.getIntegerAttr(rewriter.getI1Type(),
 APInt(/*numBits=*/1, /*val=*/1)));
 Value assertFailed = rewriter.create<XOrOp>(op.getLoc(), adaptor.arg(), c1);
-rewriter.replaceOpWithNewOp<npcomprt::AbortIfOp>(op, assertFailed,
+rewriter.replaceOpWithNewOp<refbackrt::AbortIfOp>(op, assertFailed,
 op.msgAttr());
 return success();
 }
 };
 } // namespace
 namespace {
-// At ABI bondaries, use !npcomprt.tensor instead of memref.
+// At ABI bondaries, use !refbackrt.tensor instead of memref.
 class FuncOpSignatureConversion : public OpConversionPattern<FuncOp> {
 public:
 using OpConversionPattern::OpConversionPattern;
@@ -159,7 +159,7 @@ public:
 for (auto newAndOldArg :
 llvm::zip(newEntry.getArguments(), oldEntry.getArguments())) {
 std::tie(newArg, oldArg) = newAndOldArg;
-auto abiMemref = rewriter.create<npcomprt::ToMemrefOp>(
+auto abiMemref = rewriter.create<refbackrt::ToMemrefOp>(
 op.getLoc(), getABIMemrefType(oldArg.getType()), newArg);
 auto memref = rewriter.create<MemRefCastOp>(op.getLoc(), abiMemref,
 oldArg.getType());
@@ -172,7 +172,7 @@ public:
 } // namespace
 namespace {
-// At the return ABI boundaries, convert to !npcomprt.tensor type.
+// At the return ABI boundaries, convert to !refbackrt.tensor type.
 // This pattern is needed to trigger the type conversion mechanics to do a
 // target materialization.
 class RewriteReturnOp : public OpConversionPattern<ReturnOp> {
@@ -193,20 +193,20 @@ static LogicalResult doDialectConversion(ModuleOp module) {
 TypeConverter typeConverter;
 typeConverter.addConversion([](Type type) { return type; });
 typeConverter.addConversion([](MemRefType type) {
-return npcomprt::TensorType::get(type.getContext());
+return refbackrt::TensorType::get(type.getContext());
 });
 typeConverter.addTargetMaterialization(
-[](OpBuilder &builder, npcomprt::TensorType type, ValueRange inputs,
+[](OpBuilder &builder, refbackrt::TensorType type, ValueRange inputs,
 Location loc) -> Value {
 assert(inputs.size() == 1);
 auto abiMemref = builder.create<MemRefCastOp>(
 loc, inputs[0], getABIMemrefType(inputs[0].getType()));
-return builder.create<npcomprt::FromMemrefOp>(loc, type, abiMemref);
+return builder.create<refbackrt::FromMemrefOp>(loc, type, abiMemref);
 });
 OwningRewritePatternList patterns;
 ConversionTarget target(*context);
-target.addLegalDialect<npcomprt::NpcomprtDialect>();
+target.addLegalDialect<refbackrt::RefbackrtDialect>();
 target.addLegalDialect<StandardOpsDialect>();
 patterns.insert<FuncOpSignatureConversion>(typeConverter, context);
@@ -230,10 +230,11 @@ static LogicalResult doDialectConversion(ModuleOp module) {
 namespace {
 // This pass lowers the public ABI of the module to the primitives exposed by
-// the npcomprt dialect.
+// the refbackrt dialect.
-class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
+class LowerToRefbackrtABI
+: public LowerToRefbackrtABIBase<LowerToRefbackrtABI> {
 void getDependentDialects(DialectRegistry &registry) const override {
-registry.insert<npcomprt::NpcomprtDialect>();
+registry.insert<refbackrt::RefbackrtDialect>();
 }
 void runOnOperation() override {
@@ -254,6 +255,6 @@ class LowerToNpcomprtABI : public LowerToNpcomprtABIBase<LowerToNpcomprtABI> {
 } // namespace
 std::unique_ptr<OperationPass<ModuleOp>>
-mlir::NPCOMP::createLowerToNpcomprtABIPass() {
+mlir::NPCOMP::createLowerToRefbackrtABIPass() {
-return std::make_unique<LowerToNpcomprtABI>();
+return std::make_unique<LowerToRefbackrtABI>();
 }

@ -292,12 +292,12 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
pm.addPass(createLowerToCFGPass()); pm.addPass(createLowerToCFGPass());
// Convert functions signatures and other constructs that interface with the // Convert functions signatures and other constructs that interface with the
// runtime to the `npcomprt` dialect. // runtime to the `refbackrt` dialect.
pm.addPass(createLowerToNpcomprtABIPass()); pm.addPass(createLowerToRefbackrtABIPass());
// Finally, convert to LLVM dialect using our custom LowerToLLVM pass // Finally, convert to LLVM dialect using our custom LowerToLLVM pass
// which reuses the upstream patterns and gives us a place to add our own // which reuses the upstream patterns and gives us a place to add our own
// patterns for our own custom ops like the npcomprt ops. // patterns for our own custom ops like the refbackrt ops.
pm.addPass(createLowerToLLVMPass()); pm.addPass(createLowerToLLVMPass());
// Although LLVM will clean everything up eventually, for the sake of IR // Although LLVM will clean everything up eventually, for the sake of IR

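For orientation, the two renamed passes are intended to run back to back near the end of the reference backend pipeline, as the comments in the hunk above describe. The following is a minimal C++ sketch of that ordering using only the factory functions visible in this commit (createLowerToRefbackrtABIPass, createLowerToLLVMPass); the header path and the wrapper function are assumptions for illustration, not part of the change.

#include "mlir/Pass/PassManager.h"
#include "npcomp/RefBackend/RefBackend.h" // assumed location of the pass factory declarations

// Sketch: append the ABI lowering and the custom LLVM lowering to a pipeline,
// mirroring the ordering in createRefBackendLoweringPipeline above.
static void addRefbackrtLowering(mlir::PassManager &pm) {
  // Convert public function signatures and related constructs to the
  // refbackrt dialect (formerly the npcomprt ABI lowering).
  pm.addPass(mlir::NPCOMP::createLowerToRefbackrtABIPass());
  // Then lower the result, including the refbackrt ops, to the LLVM dialect
  // (namespace of createLowerToLLVMPass assumed to be mlir::NPCOMP, per the
  // "our custom LowerToLLVM pass" comment above).
  pm.addPass(mlir::NPCOMP::createLowerToLLVMPass());
}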
@ -7,7 +7,7 @@ set(LLVM_OPTIONAL_SOURCES
) )
# The library that users link against, defining basic interactions with an # The library that users link against, defining basic interactions with an
# npcomprt module and the relevant data structures. # refbackrt module and the relevant data structures.
add_mlir_library(NPCOMPRuntime add_mlir_library(NPCOMPRuntime
Runtime.cpp Runtime.cpp

@ -17,7 +17,7 @@
#include <cstdint> #include <cstdint>
namespace npcomprt { namespace refbackrt {
// All arguments are packed into this type-erased form for being invoked. See // All arguments are packed into this type-erased form for being invoked. See
// LowerToLLVM.cpp for more details. // LowerToLLVM.cpp for more details.
@ -55,6 +55,6 @@ struct GlobalDescriptor {
void *data; void *data;
}; };
} // namespace npcomprt } // namespace refbackrt
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H #endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H

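The comment above about arguments being packed into a type-erased form corresponds to the compiler-emitted __refbackrt_wrapper_* functions, whose signature in the LLVM lowering tests later in this commit is void(void**, void**). Below is a minimal caller-side sketch of that calling convention; only the signature is taken from the commit, and the function and variable names are hypothetical.

#include <cstddef>

// Each compiler-generated wrapper takes a packed array of erased input
// pointers and a packed array of output slots, matching the
// void (ptr<ptr<i8>>, ptr<ptr<i8>>) type seen in the LLVM lowering tests.
using WrapperFn = void (*)(void **packedInputs, void **packedOutputs);

// Hypothetical caller: pack the erased pointers, call the wrapper, and read
// back the result pointer the wrapper wrote into the output slot.
inline void callWrapper(WrapperFn wrapper, void *input, void *&result) {
  void *inputs[1] = {input};    // e.g. a refbackrt::Tensor*
  void *outputs[1] = {nullptr}; // filled in by the wrapper
  wrapper(inputs, outputs);
  result = outputs[0];
}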
@ -18,7 +18,7 @@
#include "CompilerDataStructures.h" #include "CompilerDataStructures.h"
#include "npcomp/RefBackend/Runtime/UserAPI.h" #include "npcomp/RefBackend/Runtime/UserAPI.h"
using namespace npcomprt; using namespace refbackrt;
extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) { extern "C" void __npcomp_compiler_rt_abort_if(bool b, const char *msg) {
if (b) { if (b) {

@ -15,7 +15,7 @@
#include "CompilerDataStructures.h" #include "CompilerDataStructures.h"
using namespace npcomprt; using namespace refbackrt;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Tensor // Tensor
@ -29,7 +29,7 @@ static std::int32_t totalElements(ArrayRef<std::int32_t> extents) {
return ret; return ret;
} }
std::int32_t npcomprt::getElementTypeByteSize(ElementType type) { std::int32_t refbackrt::getElementTypeByteSize(ElementType type) {
switch (type) { switch (type) {
case ElementType::F32: case ElementType::F32:
return 4; return 4;
@ -85,9 +85,9 @@ static FuncDescriptor *getFuncDescriptor(ModuleDescriptor *moduleDescriptor,
return nullptr; return nullptr;
} }
void npcomprt::invoke(ModuleDescriptor *moduleDescriptor, void refbackrt::invoke(ModuleDescriptor *moduleDescriptor,
StringRef functionName, ArrayRef<Ref<Tensor>> inputs, StringRef functionName, ArrayRef<Ref<Tensor>> inputs,
MutableArrayRef<Ref<Tensor>> outputs) { MutableArrayRef<Ref<Tensor>> outputs) {
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName); auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
assert(descriptor && "unknown function name"); assert(descriptor && "unknown function name");
assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity"); assert(inputs.size() < kMaxArity && "number of inputs exceeds kMaxArity");
@ -104,7 +104,7 @@ void npcomprt::invoke(ModuleDescriptor *moduleDescriptor,
for (int i = 0, e = outputs.size(); i < e; i++) for (int i = 0, e = outputs.size(); i < e; i++)
outputTensorPtrs[i] = static_cast<Tensor *>(packedOutputs[i]); outputTensorPtrs[i] = static_cast<Tensor *>(packedOutputs[i]);
// TODO: Actually manage refcounts inside the compiler. // TODO: Actually manage refcounts inside the compiler.
// Right now, we only pass around npcomprt.tensor's in trivial ways on ABI // Right now, we only pass around refbackrt.tensor's in trivial ways on ABI
// boundaries, so the following contract of the compiler-generated code works: // boundaries, so the following contract of the compiler-generated code works:
// - input tensors are never retained or released // - input tensors are never retained or released
// - output tensors always have refcount 0. Hence the next line here is // - output tensors always have refcount 0. Hence the next line here is
@ -113,9 +113,9 @@ void npcomprt::invoke(ModuleDescriptor *moduleDescriptor,
outputs[i] = Ref<Tensor>(outputTensorPtrs[i]); outputs[i] = Ref<Tensor>(outputTensorPtrs[i]);
} }
LogicalResult npcomprt::getMetadata(ModuleDescriptor *moduleDescriptor, LogicalResult refbackrt::getMetadata(ModuleDescriptor *moduleDescriptor,
StringRef functionName, StringRef functionName,
FunctionMetadata &outMetadata) { FunctionMetadata &outMetadata) {
auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName); auto *descriptor = getFuncDescriptor(moduleDescriptor, functionName);
if (!descriptor) if (!descriptor)
return failure(); return failure();

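To make the refcount contract described above concrete, here is a minimal sketch of a caller driving the renamed runtime entry point. Only Tensor::create, Ref<Tensor>, ElementType::F32, ArrayRef, and invoke appear in this commit; how the ModuleDescriptor is obtained, the default-constructibility of Ref<Tensor>, the MutableArrayRef/StringRef spellings, and the function name are assumptions for illustration.

#include "npcomp/RefBackend/Runtime/UserAPI.h"

#include <cstdint>

using namespace refbackrt;

// Sketch: wrap a float buffer, call a 1-input/1-output function, and let the
// Ref<Tensor> handles manage the refcounts per the contract above.
void runUnaryFunction(ModuleDescriptor *module, const char *funcName,
                      float *data, std::int32_t numElements) {
  std::int32_t extents[1] = {numElements};
  // Tensor::create deep-copies the buffer into a new ref-counted tensor.
  Ref<Tensor> inputs[1] = {Tensor::create(
      ArrayRef<std::int32_t>(extents, 1), ElementType::F32, data)};
  // invoke() adopts the refcount-0 output tensors into these handles
  // (assumes Ref<Tensor> is default-constructible, as invoke() implies).
  Ref<Tensor> outputs[1];
  invoke(module, funcName, ArrayRef<Ref<Tensor>>(inputs, 1),
         MutableArrayRef<Ref<Tensor>>(outputs, 1));
}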
@ -1,43 +0,0 @@
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
npcomprt.module_metadata {
// expected-error @+1 {{must reference a valid func}}
npcomprt.func_metadata {funcName = @g, numInputs = 1 : i32, numOutputs = 0 : i32}
}
// -----
npcomprt.module_metadata {
// expected-error @+1 {{must agree on number of inputs}}
npcomprt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
}
func @f() { return }
// -----
npcomprt.module_metadata {
// expected-error @+1 {{must agree on number of outputs}}
npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
}
func @f() { return }
// -----
npcomprt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid npcomprt.global}}
npcomprt.get_global @nonexistent_symbol : memref<*xf32>
return
}
// -----
npcomprt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
npcomprt.get_global @g : memref<*xi8>
return
}

@ -1,20 +0,0 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
// CHECK: npcomprt.module_metadata
npcomprt.module_metadata {
// CHECK: npcomprt.func_metadata
npcomprt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
}
// CHECK-LABEL: func @f
// CHECK-SAME: !npcomprt.tensor
func @f(%arg0: !npcomprt.tensor) {
return
}
// CHECK-LABEL: npcomprt.global @g dense<0.0{{.*}}> : tensor<10xf32>
npcomprt.global @g dense<0.0> : tensor<10xf32>
func @uses_global() {
npcomprt.get_global @g : memref<*xf32>
return
}

@ -0,0 +1,43 @@
// RUN: npcomp-opt <%s -split-input-file -verify-diagnostics
refbackrt.module_metadata {
// expected-error @+1 {{must reference a valid func}}
refbackrt.func_metadata {funcName = @g, numInputs = 1 : i32, numOutputs = 0 : i32}
}
// -----
refbackrt.module_metadata {
// expected-error @+1 {{must agree on number of inputs}}
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
}
func @f() { return }
// -----
refbackrt.module_metadata {
// expected-error @+1 {{must agree on number of outputs}}
refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
}
func @f() { return }
// -----
refbackrt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid refbackrt.global}}
refbackrt.get_global @nonexistent_symbol : memref<*xf32>
return
}
// -----
refbackrt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
refbackrt.get_global @g : memref<*xi8>
return
}

@ -0,0 +1,20 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s --dump-input=fail
// CHECK: refbackrt.module_metadata
refbackrt.module_metadata {
// CHECK: refbackrt.func_metadata
refbackrt.func_metadata {funcName = @f, numInputs = 1 : i32, numOutputs = 0 : i32}
}
// CHECK-LABEL: func @f
// CHECK-SAME: !refbackrt.tensor
func @f(%arg0: !refbackrt.tensor) {
return
}
// CHECK-LABEL: refbackrt.global @g dense<0.0{{.*}}> : tensor<10xf32>
refbackrt.global @g dense<0.0> : tensor<10xf32>
func @uses_global() {
refbackrt.get_global @g : memref<*xf32>
return
}

@ -5,17 +5,17 @@
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)> // CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8> // CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)> // CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float> // CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float>
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32> // CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32>
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> { // CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> // CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> // CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomprt_global_extents_g : !llvm.ptr<array<1 x i32>> // CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__refbackrt_global_extents_g : !llvm.ptr<array<1 x i32>>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x i32>> to !llvm.ptr<i32> // CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x i32>> to !llvm.ptr<i32>
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> // CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__npcomprt_global_data_buffer_g : !llvm.ptr<array<3 x float>> // CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__refbackrt_global_data_buffer_g : !llvm.ptr<array<3 x float>>
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr<array<3 x float>> to !llvm.ptr<i8> // CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr<array<3 x float>> to !llvm.ptr<i8>
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> // CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> // CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
@ -44,15 +44,15 @@
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr<i8>)> // CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr<i8>)> // CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: } // CHECK: }
npcomprt.global @g dense<7.000000e+00> : tensor<3xf32> refbackrt.global @g dense<7.000000e+00> : tensor<3xf32>
func @calls_get_global() -> memref<*xf32> { func @calls_get_global() -> memref<*xf32> {
%0 = npcomprt.get_global @g : memref<*xf32> %0 = refbackrt.get_global @g : memref<*xf32>
return %0 : memref<*xf32> return %0 : memref<*xf32>
} }
// ----- // -----
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy. // For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm.array<1 x float> // CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm.array<1 x float>
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32> // CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32>
npcomprt.global @g dense<7.0> : tensor<f32> refbackrt.global @g dense<7.0> : tensor<f32>

@ -1,6 +1,6 @@
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail // RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_identity( // CHECK-LABEL: llvm.func @__refbackrt_wrapper_identity(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>, // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) { // CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
@ -28,7 +28,7 @@
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_identity : !llvm.ptr<array<8 x i8>> // CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_identity : !llvm.ptr<array<8 x i8>>
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<8 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8> // CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<8 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_identity : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> // CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_identity : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8> // CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
@ -48,8 +48,8 @@
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> // CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: } // CHECK: }
npcomprt.module_metadata { refbackrt.module_metadata {
npcomprt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32} refbackrt.func_metadata {funcName = @identity, numInputs = 1 : i32, numOutputs = 1 : i32}
} }
@ -57,15 +57,15 @@ npcomprt.module_metadata {
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> { // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8> // CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
// CHECK: } // CHECK: }
func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor { func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
return %arg0 : !npcomprt.tensor return %arg0 : !refbackrt.tensor
} }
// ----- // -----
// Test input/output arg marshaling. // Test input/output arg marshaling.
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_inputs1results2( // CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results2(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>, // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) { // CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
@ -86,7 +86,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: llvm.return // CHECK: llvm.return
// CHECK: } // CHECK: }
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_inputs1results1( // CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results1(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>, // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) { // CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
@ -101,7 +101,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: llvm.return // CHECK: llvm.return
// CHECK: } // CHECK: }
// CHECK-LABEL: llvm.func @__npcomprt_wrapper_inputs1results0( // CHECK-LABEL: llvm.func @__refbackrt_wrapper_inputs1results0(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>, // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<ptr<i8>>,
// CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) { // CHECK-SAME: %[[VAL_1:.*]]: !llvm.ptr<ptr<i8>>) {
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32 // CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(0 : i32) : !llvm.i32
@ -127,7 +127,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>> // CHECK: %[[VAL_4:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results0 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8> // CHECK: %[[VAL_5:.*]] = llvm.getelementptr %[[VAL_4]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_6:.*]] = llvm.insertvalue %[[VAL_5]], %[[VAL_3]][0 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> // CHECK: %[[VAL_7:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results0 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8> // CHECK: %[[VAL_8:.*]] = llvm.bitcast %[[VAL_7]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_9:.*]] = llvm.insertvalue %[[VAL_8]], %[[VAL_6]][0 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: %[[VAL_10:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
@ -139,7 +139,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>> // CHECK: %[[VAL_16:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results1 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8> // CHECK: %[[VAL_17:.*]] = llvm.getelementptr %[[VAL_16]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_17]], %[[VAL_15]][1 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> // CHECK: %[[VAL_19:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results1 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8> // CHECK: %[[VAL_20:.*]] = llvm.bitcast %[[VAL_19]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_21:.*]] = llvm.insertvalue %[[VAL_20]], %[[VAL_18]][1 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: %[[VAL_22:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
@ -151,7 +151,7 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>> // CHECK: %[[VAL_28:.*]] = llvm.mlir.addressof @__npcomp_internal_constant_inputs1results2 : !llvm.ptr<array<15 x i8>>
// CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8> // CHECK: %[[VAL_29:.*]] = llvm.getelementptr %[[VAL_28]]{{\[}}%[[VAL_1]], %[[VAL_1]]] : (!llvm.ptr<array<15 x i8>>, !llvm.i32, !llvm.i32) -> !llvm.ptr<i8>
// CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_30:.*]] = llvm.insertvalue %[[VAL_29]], %[[VAL_27]][2 : i32, 1 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__npcomprt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> // CHECK: %[[VAL_31:.*]] = llvm.mlir.addressof @__refbackrt_wrapper_inputs1results2 : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>>
// CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8> // CHECK: %[[VAL_32:.*]] = llvm.bitcast %[[VAL_31]] : !llvm.ptr<func<void (ptr<ptr<i8>>, ptr<ptr<i8>>)>> to !llvm.ptr<i8>
// CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> // CHECK: %[[VAL_33:.*]] = llvm.insertvalue %[[VAL_32]], %[[VAL_30]][2 : i32, 2 : i32] : !llvm.array<3 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>
// CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32 // CHECK: %[[VAL_34:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
@ -171,10 +171,10 @@ func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)> // CHECK: llvm.return %[[VAL_5]] : !llvm.struct<(i32, ptr<struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>>)>
// CHECK: } // CHECK: }
npcomprt.module_metadata { refbackrt.module_metadata {
npcomprt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32} refbackrt.func_metadata {funcName = @inputs1results0, numInputs = 1 : i32, numOutputs = 0 : i32}
npcomprt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32} refbackrt.func_metadata {funcName = @inputs1results1, numInputs = 1 : i32, numOutputs = 1 : i32}
npcomprt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32} refbackrt.func_metadata {funcName = @inputs1results2, numInputs = 1 : i32, numOutputs = 2 : i32}
} }
@ -183,7 +183,7 @@ npcomprt.module_metadata {
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) { // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) {
// CHECK: llvm.return // CHECK: llvm.return
// CHECK: } // CHECK: }
func @inputs1results0(%arg0: !npcomprt.tensor) { func @inputs1results0(%arg0: !refbackrt.tensor) {
return return
} }
@ -191,8 +191,8 @@ func @inputs1results0(%arg0: !npcomprt.tensor) {
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> { // CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<i8>) -> !llvm.ptr<i8> {
// CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8> // CHECK: llvm.return %[[VAL_0]] : !llvm.ptr<i8>
// CHECK: } // CHECK: }
func @inputs1results1(%arg0: !npcomprt.tensor) -> !npcomprt.tensor { func @inputs1results1(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
return %arg0 : !npcomprt.tensor return %arg0 : !refbackrt.tensor
} }
// CHECK-LABEL: llvm.func @inputs1results2( // CHECK-LABEL: llvm.func @inputs1results2(
@ -202,8 +202,8 @@ func @inputs1results1(%arg0: !npcomprt.tensor) -> !npcomprt.tensor {
// CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1] : !llvm.struct<(ptr<i8>, ptr<i8>)> // CHECK: %[[VAL_3:.*]] = llvm.insertvalue %[[VAL_0]], %[[VAL_2]][1] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr<i8>, ptr<i8>)> // CHECK: llvm.return %[[VAL_3]] : !llvm.struct<(ptr<i8>, ptr<i8>)>
// CHECK: } // CHECK: }
func @inputs1results2(%arg0: !npcomprt.tensor) -> (!npcomprt.tensor, !npcomprt.tensor) { func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackrt.tensor) {
return %arg0, %arg0 : !npcomprt.tensor, !npcomprt.tensor return %arg0, %arg0 : !refbackrt.tensor, !refbackrt.tensor
} }
@ -226,7 +226,7 @@ func @inputs1results2(%arg0: !npcomprt.tensor) -> (!npcomprt.tensor, !npcomprt.t
// CHECK: llvm.return // CHECK: llvm.return
func @calls_abort_if(%arg0: i1) { func @calls_abort_if(%arg0: i1) {
npcomprt.abort_if %arg0, "msg" refbackrt.abort_if %arg0, "msg"
return return
} }
@ -235,8 +235,8 @@ func @calls_abort_if(%arg0: i1) {
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_to_memref(%[[VAL_0]]) : (!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)> // CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_to_memref(%[[VAL_0]]) : (!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.return // CHECK: llvm.return
// CHECK: } // CHECK: }
func @calls_to_memref(%arg0: !npcomprt.tensor) { func @calls_to_memref(%arg0: !refbackrt.tensor) {
%0 = npcomprt.to_memref %arg0 : memref<*xf32> %0 = refbackrt.to_memref %arg0 : memref<*xf32>
return return
} }
@ -251,7 +251,7 @@ func @calls_to_memref(%arg0: !npcomprt.tensor) {
// CHECK: %[[VAL_7:.*]] = llvm.call @__npcomp_compiler_rt_from_memref(%[[VAL_5]], %[[VAL_6]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8> // CHECK: %[[VAL_7:.*]] = llvm.call @__npcomp_compiler_rt_from_memref(%[[VAL_5]], %[[VAL_6]]) : (!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.return %[[VAL_7]] : !llvm.ptr<i8> // CHECK: llvm.return %[[VAL_7]] : !llvm.ptr<i8>
// CHECK: } // CHECK: }
func @calls_from_memref(%arg0: memref<*xf32>) -> !npcomprt.tensor { func @calls_from_memref(%arg0: memref<*xf32>) -> !refbackrt.tensor {
%0 = npcomprt.from_memref %arg0 : memref<*xf32> %0 = refbackrt.from_memref %arg0 : memref<*xf32>
return %0 : !npcomprt.tensor return %0 : !refbackrt.tensor
} }

@ -1,10 +1,10 @@
// RUN: npcomp-opt -lower-to-npcomprt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail // RUN: npcomp-opt -lower-to-refbackrt-abi -split-input-file -verify-diagnostics <%s | FileCheck %s --dump-input=fail
// Test module metadata. // Test module metadata.
// CHECK: npcomprt.module_metadata // CHECK: refbackrt.module_metadata
// CHECK-NEXT: npcomprt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32} // CHECK-NEXT: refbackrt.func_metadata {funcName = @f_2inputs_0outputs, numInputs = 2 : i32, numOutputs = 0 : i32}
// CHECK-NEXT: npcomprt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32} // CHECK-NEXT: refbackrt.func_metadata {funcName = @f_1input_2outputs, numInputs = 1 : i32, numOutputs = 2 : i32}
// This function only exists to test its metadata above. // This function only exists to test its metadata above.
func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) { func @f_2inputs_0outputs(%arg0: memref<?xf32>, %arg1: memref<?xf32>) {
@ -20,12 +20,12 @@ func @f_1input_2outputs(%arg0: memref<?xf32>) -> (memref<?xf32>, memref<?xf32>)
// Test ABI conversions. // Test ABI conversions.
// CHECK-LABEL: func @identity(%arg0: !npcomprt.tensor) -> !npcomprt.tensor // CHECK-LABEL: func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
func @identity(%arg0: memref<?xf32>) -> memref<?xf32> { func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
// The argument materialization. // The argument materialization.
// In this test case, these go unused since, as described below, the new // In this test case, these go unused since, as described below, the new
// argument value is seen immediately by the return op for some reason. // argument value is seen immediately by the return op for some reason.
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32> // CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32> // CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
// TODO: Why do these target materializations not happen in this particular // TODO: Why do these target materializations not happen in this particular
@ -34,7 +34,7 @@ func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
// rather than the result of replaceUsesOfBlockArgument from // rather than the result of replaceUsesOfBlockArgument from
// FuncOpSignatureConversion // FuncOpSignatureConversion
// Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32> // Cxxxx-NEXT: %[[OUTABIMEMREF:.*]] = memref_cast %[[MEMREF]] : memref<?xf32> to memref<*xf32>
// Cxxxx-NEXT: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32> // Cxxxx-NEXT: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
// Cxxxx-NEXT: return %[[RET]] // Cxxxx-NEXT: return %[[RET]]
// CHECK-NEXT: return %arg0 // CHECK-NEXT: return %arg0
@ -44,9 +44,9 @@ func @identity(%arg0: memref<?xf32>) -> memref<?xf32> {
// ----- // -----
// CHECK-LABEL: func @use_of_arg(%arg0: !npcomprt.tensor) // CHECK-LABEL: func @use_of_arg(%arg0: !refbackrt.tensor)
func @use_of_arg(%arg0: memref<?xf32>) { func @use_of_arg(%arg0: memref<?xf32>) {
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32> // CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
// CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32> // CHECK-NEXT: %[[MEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
%c0 = constant 0 : index %c0 = constant 0 : index
%0 = dim %arg0, %c0 : memref<?xf32> %0 = dim %arg0, %c0 : memref<?xf32>
@ -57,32 +57,32 @@ func @use_of_arg(%arg0: memref<?xf32>) {
// ----- // -----
// CHECK-LABEL: func @multiple_blocks(%arg0: !npcomprt.tensor) -> !npcomprt.tensor // CHECK-LABEL: func @multiple_blocks(%arg0: !refbackrt.tensor) -> !refbackrt.tensor
func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> { func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
// CHECK-NEXT: %[[INABIMEMREF:.*]] = npcomprt.to_memref %arg0 : memref<*xf32> // CHECK-NEXT: %[[INABIMEMREF:.*]] = refbackrt.to_memref %arg0 : memref<*xf32>
// CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32> // CHECK-NEXT: %[[INMEMREF:.*]] = memref_cast %[[INABIMEMREF]] : memref<*xf32> to memref<?xf32>
// CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>) // CHECK-NEXT: br ^bb1(%[[INMEMREF]] : memref<?xf32>)
br ^bb1(%arg0: memref<?xf32>) br ^bb1(%arg0: memref<?xf32>)
// CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>): // CHECK-NEXT: ^bb1(%[[BBARG:.*]]: memref<?xf32>):
^bb1(%bbarg: memref<?xf32>): ^bb1(%bbarg: memref<?xf32>):
// CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32> // CHECK-NEXT: %[[OUTMEMREF:.*]] = memref_cast %[[BBARG]] : memref<?xf32> to memref<*xf32>
// CHECK-NEXT: %[[OUTABIMEMREF:.*]] = npcomprt.from_memref %[[OUTMEMREF]] : memref<*xf32> // CHECK-NEXT: %[[OUTABIMEMREF:.*]] = refbackrt.from_memref %[[OUTMEMREF]] : memref<*xf32>
// CHECK-NEXT: return %[[OUTABIMEMREF]] : !npcomprt.tensor // CHECK-NEXT: return %[[OUTABIMEMREF]] : !refbackrt.tensor
return %bbarg : memref<?xf32> return %bbarg : memref<?xf32>
} }
// ----- // -----
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32> // CHECK: refbackrt.global @g dense<7.000000e+00> : tensor<10xf32>
refback.global @g dense<7.0> : tensor<10xf32> refback.global @g dense<7.0> : tensor<10xf32>
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor // CHECK-LABEL: func @gets_global() -> !refbackrt.tensor
func @gets_global() -> memref<10xf32> { func @gets_global() -> memref<10xf32> {
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32> // CHECK: %[[GMEMREF:.*]] = refbackrt.get_global @g : memref<*xf32>
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32> // CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32> // CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[OUTABIMEMREF]] : memref<*xf32> // CHECK: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : !npcomprt.tensor // CHECK: return %[[RET]] : !refbackrt.tensor
%0 = refback.get_global_memref @g : memref<10xf32> %0 = refback.get_global_memref @g : memref<10xf32>
return %0 : memref<10xf32> return %0 : memref<10xf32>
} }
@ -91,7 +91,7 @@ func @gets_global() -> memref<10xf32> {
// Test diagnostics. // Test diagnostics.
// expected-error @+1 {{func not expressible with npcomprt ABI}} // expected-error @+1 {{func not expressible with refbackrt ABI}}
func @unhandled_abi_type_on_public_func(%arg0: i32) { func @unhandled_abi_type_on_public_func(%arg0: i32) {
return return
} }

@ -73,32 +73,30 @@ invokeJITModuleWithATenTensors(npcomp::JITModule &jitModule,
std::vector<at::TensorArg> tensorArgs; std::vector<at::TensorArg> tensorArgs;
for (auto arg : llvm::enumerate(args)) for (auto arg : llvm::enumerate(args))
tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index())); tensorArgs.push_back(at::TensorArg(arg.value(), "arg", arg.index()));
at::CheckedFrom c = "converting to npcomprt::Tensor"; at::CheckedFrom c = "converting to refbackrt::Tensor";
for (auto &tensorArg : tensorArgs) for (auto &tensorArg : tensorArgs)
at::checkScalarType(c, tensorArg, at::ScalarType::Float); at::checkScalarType(c, tensorArg, at::ScalarType::Float);
at::checkAllContiguous(c, tensorArgs); at::checkAllContiguous(c, tensorArgs);
SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> npcomprtInputs; SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> refbackInputs;
for (at::Tensor arg : args) { for (at::Tensor arg : args) {
SmallVector<int32_t, 6> extents(arg.sizes().begin(), arg.sizes().end()); SmallVector<int32_t, 6> extents(arg.sizes().begin(), arg.sizes().end());
float *data = arg.storage().data<float>(); float *data = arg.storage().data<float>();
// This does a deep copy of the data. Let's see if it shows up on the // This does a deep copy of the data. Let's see if it shows up on the
// profile. // profile.
npcomprtInputs.push_back(npcomprt::Tensor::create( refbackInputs.push_back(refbackrt::Tensor::create(
npcomprt::ArrayRef<int32_t>(extents.data(), extents.size()), refbackrt::ArrayRef<int32_t>(extents.data(), extents.size()),
npcomprt::ElementType::F32, data)); refbackrt::ElementType::F32, data));
} }
// Invoke the RefBackend function. // Invoke the RefBackend function.
// TODO: The mishmash of terminology "npcomprt", "refback", "npcomp" in this auto expectedOutputs = jitModule.invoke(invokeFunction, refbackInputs);
// file is getting out of hand.
auto expectedOutputs = jitModule.invoke(invokeFunction, npcomprtInputs);
if (!expectedOutputs) if (!expectedOutputs)
return expectedOutputs.takeError(); return expectedOutputs.takeError();
auto npcomprtOutputs = std::move(*expectedOutputs); auto refbackrtOutputs = std::move(*expectedOutputs);
std::vector<at::Tensor> results; std::vector<at::Tensor> results;
for (auto output : npcomprtOutputs) { for (auto output : refbackrtOutputs) {
std::vector<int64_t> sizes(output->getExtents().data(), std::vector<int64_t> sizes(output->getExtents().data(),
output->getExtents().data() + output->getExtents().data() +
output->getExtents().size()); output->getExtents().size());

@ -35,7 +35,7 @@ static Error make_string_error(const Twine &message) {
llvm::inconvertibleErrorCode()); llvm::inconvertibleErrorCode());
} }
static Expected<npcomprt::Ref<npcomprt::Tensor>> static Expected<refbackrt::Ref<refbackrt::Tensor>>
convertAttrToTensor(Attribute attr) { convertAttrToTensor(Attribute attr) {
auto type = attr.getType().dyn_cast<RankedTensorType>(); auto type = attr.getType().dyn_cast<RankedTensorType>();
if (!type) if (!type)
@ -47,18 +47,18 @@ convertAttrToTensor(Attribute attr) {
if (elementType.isF32()) { if (elementType.isF32()) {
auto values = llvm::to_vector<100>(llvm::map_range( auto values = llvm::to_vector<100>(llvm::map_range(
denseFp, [](APFloat f) { return f.convertToFloat(); })); denseFp, [](APFloat f) { return f.convertToFloat(); }));
return npcomprt::Tensor::create( return refbackrt::Tensor::create(
npcomprt::ArrayRef<std::int32_t>(extents.data(), extents.size()), refbackrt::ArrayRef<std::int32_t>(extents.data(), extents.size()),
npcomprt::ElementType::F32, static_cast<void *>(values.data())); refbackrt::ElementType::F32, static_cast<void *>(values.data()));
} }
} }
return make_string_error("unhandled argument"); return make_string_error("unhandled argument");
} }
static Expected<SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6>> static Expected<SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6>>
createInputs(ArrayRef<StringRef> argValues) { createInputs(ArrayRef<StringRef> argValues) {
MLIRContext context; MLIRContext context;
SmallVector<npcomprt::Ref<npcomprt::Tensor>, 6> ret; SmallVector<refbackrt::Ref<refbackrt::Tensor>, 6> ret;
for (auto argValue : argValues) { for (auto argValue : argValues) {
auto attr = parseAttribute(argValue, &context); auto attr = parseAttribute(argValue, &context);
if (!attr) if (!attr)
@ -71,14 +71,14 @@ createInputs(ArrayRef<StringRef> argValues) {
return ret; return ret;
} }
static Type convertToMLIRType(npcomprt::ElementType type, Builder &builder) { static Type convertToMLIRType(refbackrt::ElementType type, Builder &builder) {
switch (type) { switch (type) {
case npcomprt::ElementType::F32: case refbackrt::ElementType::F32:
return builder.getF32Type(); return builder.getF32Type();
} }
} }
static RankedTensorType getCorrespondingMLIRTensorType(npcomprt::Tensor &tensor, static RankedTensorType getCorrespondingMLIRTensorType(refbackrt::Tensor &tensor,
Builder &builder) { Builder &builder) {
auto elementType = convertToMLIRType(tensor.getElementType(), builder); auto elementType = convertToMLIRType(tensor.getElementType(), builder);
SmallVector<int64_t, 6> extents; SmallVector<int64_t, 6> extents;
@ -87,11 +87,11 @@ static RankedTensorType getCorrespondingMLIRTensorType(npcomprt::Tensor &tensor,
return RankedTensorType::get(extents, elementType); return RankedTensorType::get(extents, elementType);
} }
static Attribute convertToMLIRAttribute(npcomprt::Tensor &tensor, static Attribute convertToMLIRAttribute(refbackrt::Tensor &tensor,
Builder &builder) { Builder &builder) {
RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder); RankedTensorType type = getCorrespondingMLIRTensorType(tensor, builder);
switch (tensor.getElementType()) { switch (tensor.getElementType()) {
case npcomprt::ElementType::F32: { case refbackrt::ElementType::F32: {
SmallVector<float, 100> values; SmallVector<float, 100> values;
auto *basePtr = tensor.getData<float>(); auto *basePtr = tensor.getData<float>();
for (int i = 0, e = type.getNumElements(); i < e; i++) for (int i = 0, e = type.getNumElements(); i < e; i++)
@ -101,14 +101,14 @@ static Attribute convertToMLIRAttribute(npcomprt::Tensor &tensor,
} }
} }
static void printOutput(npcomprt::Tensor &tensor, llvm::raw_ostream &os) { static void printOutput(refbackrt::Tensor &tensor, llvm::raw_ostream &os) {
MLIRContext context; MLIRContext context;
Builder builder(&context); Builder builder(&context);
auto attr = convertToMLIRAttribute(tensor, builder); auto attr = convertToMLIRAttribute(tensor, builder);
attr.print(os); attr.print(os);
} }
static void printOutputs(ArrayRef<npcomprt::Ref<npcomprt::Tensor>> outputs, static void printOutputs(ArrayRef<refbackrt::Ref<refbackrt::Tensor>> outputs,
llvm::raw_ostream &os) { llvm::raw_ostream &os) {
for (auto output : llvm::enumerate(outputs)) { for (auto output : llvm::enumerate(outputs)) {
os << "output #" << output.index() << ": "; os << "output #" << output.index() << ": ";