mirror of https://github.com/llvm/torch-mlir
npcomprt: add support for constants
- create tcp.global + tcp.get_global_memref - create npcomprt.global + npcomprt.get_global - LLVM lowering for new npcomprt ops - Runtime: - GlobalDescriptor struct emitted by LLVM lowering - implement __npcomp_compiler_rt_get_global Also, - cleanly isolate all runtime data structure definitions shared by the compiler and runtime into lib/runtime/CompilerDataStructures.hpull/1/head
parent
2e40ce05ad
commit
e228aa4b11
|
@ -27,7 +27,7 @@ def Npcomprt_Tensor
|
||||||
: DialectType<
|
: DialectType<
|
||||||
Npcomprt_Dialect,
|
Npcomprt_Dialect,
|
||||||
CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">,
|
CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">,
|
||||||
"buffer view">,
|
"npcomprt.tensor">,
|
||||||
BuildableType<
|
BuildableType<
|
||||||
"$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> {
|
"$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> {
|
||||||
let typeDescription = [{The runtime type that represents a buffer.}];
|
let typeDescription = [{The runtime type that represents a buffer.}];
|
||||||
|
|
|
@ -12,6 +12,7 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
namespace NPCOMP {
|
namespace NPCOMP {
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
#define NPCOMPRT_OPS
|
#define NPCOMPRT_OPS
|
||||||
|
|
||||||
include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td"
|
include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td"
|
||||||
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []>
|
class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<Npcomprt_Dialect, mnemonic, traits> {
|
: Op<Npcomprt_Dialect, mnemonic, traits> {
|
||||||
|
@ -57,6 +58,44 @@ def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
|
||||||
let assemblyFormat = "$pred attr-dict";
|
let assemblyFormat = "$pred attr-dict";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
|
||||||
|
let summary = "Represents a global variable";
|
||||||
|
let description = [{
|
||||||
|
Represents a global variable.
|
||||||
|
|
||||||
|
Currently, only constant tensors are supported, and they are not
|
||||||
|
considered to be exported.
|
||||||
|
}];
|
||||||
|
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
|
||||||
|
let summary = "Obtain a rank-erased memref pointing at the given global";
|
||||||
|
let description = [{
|
||||||
|
Obtain a rank-erased memref pointing at the given global.
|
||||||
|
|
||||||
|
TODO: As we define the runtime layer better, we should have fewer
|
||||||
|
entry points that return memrefs, or at least have a clearer separation
|
||||||
|
between the "memref world" and the "npcomprt world".
|
||||||
|
Something like forming IREE dispatch regions seems to be the missing thing:
|
||||||
|
- Everything inside the dispatch regions gets things marshaled from the
|
||||||
|
runtime (flow/hal/npcomprt) layer to/from memrefs in a clear way.
|
||||||
|
- Everything outside the dispatch regions purely uses the runtime
|
||||||
|
(flow/hal/npcomprt) data structures.
|
||||||
|
Globals should be one of the things that are purely runtime data structures,
|
||||||
|
rather than using memrefs. For now, using memrefs is simpler though.
|
||||||
|
}];
|
||||||
|
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||||
|
let results = (outs AnyUnrankedMemRef:$memref);
|
||||||
|
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||||
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
|
// TODO: verify exists and shape is compatible
|
||||||
|
}
|
||||||
|
|
||||||
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
||||||
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
||||||
]> {
|
]> {
|
||||||
|
|
|
@ -13,6 +13,7 @@
|
||||||
#include "mlir/IR/OpDefinition.h"
|
#include "mlir/IR/OpDefinition.h"
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
|
#include "mlir/IR/SymbolTable.h"
|
||||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
|
@ -13,6 +13,7 @@ include "npcomp/Dialect/TCP/IR/TCPBase.td"
|
||||||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||||
|
include "mlir/IR/SymbolInterfaces.td"
|
||||||
|
|
||||||
class TCP_Op<string mnemonic, list<OpTrait> traits = []>
|
class TCP_Op<string mnemonic, list<OpTrait> traits = []>
|
||||||
: Op<TCP_Dialect, mnemonic, traits> {
|
: Op<TCP_Dialect, mnemonic, traits> {
|
||||||
|
@ -58,6 +59,32 @@ Allocates a memref of the given shape.
|
||||||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
|
||||||
|
let summary = "Represents a global variable";
|
||||||
|
let description = [{
|
||||||
|
Represents a global variable.
|
||||||
|
|
||||||
|
Currently, only constant tensors are supported, and they are not
|
||||||
|
considered to be exported.
|
||||||
|
}];
|
||||||
|
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||||
|
let results = (outs);
|
||||||
|
|
||||||
|
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||||
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
|
}
|
||||||
|
|
||||||
|
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
|
||||||
|
let summary = "Obtain a memref pointing at the given global";
|
||||||
|
let description = [{
|
||||||
|
Obtain a memref pointing at the given global.
|
||||||
|
}];
|
||||||
|
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||||
|
let results = (outs AnyMemRef:$memref);
|
||||||
|
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||||
|
let verifier = "return ::verify$cppClass(*this);";
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Change to a more principled error handling mechanism.
|
// TODO: Change to a more principled error handling mechanism.
|
||||||
// This op probably doesn't need to exist eventually.
|
// This op probably doesn't need to exist eventually.
|
||||||
// This op is also not correctly modeled right now, since it itself doesn't
|
// This op is also not correctly modeled right now, since it itself doesn't
|
||||||
|
|
|
@ -25,6 +25,9 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
|
||||||
std::unique_ptr<OperationPass<FuncOp>>
|
std::unique_ptr<OperationPass<FuncOp>>
|
||||||
createLowerLinalgOnTensorToLinalgOnMemrefPass();
|
createLowerLinalgOnTensorToLinalgOnMemrefPass();
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
createLowerConstantTensorsToMemrefsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
|
||||||
|
|
||||||
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();
|
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();
|
||||||
|
|
|
@ -23,6 +23,15 @@ def LowerBroadcastToToLoops :
|
||||||
let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()";
|
let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def LowerConstantTensorsToMemrefs :
|
||||||
|
Pass<"lower-constant-tensors-to-memrefs", "ModuleOp"> {
|
||||||
|
let summary = "Lower std.constant of tensor type to hybrid tensor/memref.";
|
||||||
|
let description = [{
|
||||||
|
This has to be a module pass since it involves creating tcp.global ops.
|
||||||
|
}];
|
||||||
|
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass()";
|
||||||
|
}
|
||||||
|
|
||||||
def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
|
def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
|
||||||
let summary = "Resolve shape.shape_of ops to other shapes.";
|
let summary = "Resolve shape.shape_of ops to other shapes.";
|
||||||
let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";
|
let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";
|
||||||
|
|
|
@ -10,11 +10,53 @@
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/SymbolTable.h"
|
#include "mlir/IR/SymbolTable.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP::npcomprt;
|
using namespace mlir::NPCOMP::npcomprt;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GlobalOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||||
|
p << "npcomprt.global ";
|
||||||
|
p.printSymbolName(op.sym_name());
|
||||||
|
p << ' ';
|
||||||
|
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||||
|
/*elidedAttrs=*/{"sym_name", "value"});
|
||||||
|
p.printAttribute(op.valueAttr());
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
StringAttr nameAttr;
|
||||||
|
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||||
|
result.attributes))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||||
|
return failure();
|
||||||
|
Attribute valueAttr;
|
||||||
|
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GetGlobalOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
|
||||||
|
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||||
|
if (!global)
|
||||||
|
return op.emitError() << "must reference a valid npcomprt.global";
|
||||||
|
auto globalType = global.value().getType();
|
||||||
|
auto resultType = op.getType().cast<ShapedType>();
|
||||||
|
if (globalType.getElementType() != resultType.getElementType())
|
||||||
|
return op.emitError() << "inconsistent with element type of global";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ModuleMetadataOp
|
// ModuleMetadataOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -8,11 +8,56 @@
|
||||||
|
|
||||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||||
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::NPCOMP;
|
using namespace mlir::NPCOMP;
|
||||||
using namespace mlir::NPCOMP::tcp;
|
using namespace mlir::NPCOMP::tcp;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GlobalOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||||
|
p << "tcp.global ";
|
||||||
|
p.printSymbolName(op.sym_name());
|
||||||
|
p << ' ';
|
||||||
|
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||||
|
/*elidedAttrs=*/{"sym_name", "value"});
|
||||||
|
p.printAttribute(op.valueAttr());
|
||||||
|
}
|
||||||
|
|
||||||
|
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||||
|
StringAttr nameAttr;
|
||||||
|
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||||
|
result.attributes))
|
||||||
|
return failure();
|
||||||
|
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||||
|
return failure();
|
||||||
|
Attribute valueAttr;
|
||||||
|
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// GetGlobalMemrefOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
|
||||||
|
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||||
|
if (!global)
|
||||||
|
return op.emitError() << "must reference a valid symbol";
|
||||||
|
auto globalType = global.value().getType();
|
||||||
|
auto resultType = op.getType().cast<ShapedType>();
|
||||||
|
if (failed(
|
||||||
|
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
|
||||||
|
return op.emitError() << "inconsistent with shape of global";
|
||||||
|
if (globalType.getElementType() != resultType.getElementType())
|
||||||
|
return op.emitError() << "inconsistent with element type of global";
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// ShapeObserveErrorOp
|
// ShapeObserveErrorOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -359,7 +359,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
||||||
pm.addPass(createResolveTensorLoadStoreOpsPass());
|
pm.addPass(createResolveTensorLoadStoreOpsPass());
|
||||||
|
|
||||||
// At this point, the IR is in a form where there are no tensor ops
|
// At this point, the IR is in a form where there are no tensor ops
|
||||||
// (except tensor_store's of arguments and tensor_load's of returns).
|
// (except tensor_store's of arguments, tensor_load's of returns, and
|
||||||
|
// constants).
|
||||||
//
|
//
|
||||||
// This is a reasonable representation for doing buffer assignment.
|
// This is a reasonable representation for doing buffer assignment.
|
||||||
// TODO: Do buffer assignment here.
|
// TODO: Do buffer assignment here.
|
||||||
|
|
|
@ -296,6 +296,91 @@ mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass() {
|
||||||
return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>();
|
return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// LowerConstantTensorsToMemrefs
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// This class creates global ops for all tensor-valued constants in the program.
|
||||||
|
// It creates them with pretty names and makes sure that duplicate globals
|
||||||
|
// aren't created.
|
||||||
|
class GlobalCreator {
|
||||||
|
public:
|
||||||
|
explicit GlobalCreator(ModuleOp module);
|
||||||
|
tcp::GlobalOp getGlobalFor(Attribute attr) {
|
||||||
|
assert(globals.find(attr) != globals.end() && "unknown constant attr");
|
||||||
|
return globals[attr];
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
DenseMap<Attribute, tcp::GlobalOp> globals;
|
||||||
|
};
|
||||||
|
|
||||||
|
GlobalCreator::GlobalCreator(ModuleOp module) {
|
||||||
|
// Create a builder without an insertion point. We will insert using the
|
||||||
|
// symbol table to guarantee unique names.
|
||||||
|
OpBuilder globalBuilder(module.getContext());
|
||||||
|
SymbolTable symbolTable(module);
|
||||||
|
module.walk([&](ConstantOp op) {
|
||||||
|
// We only want tensor constants for now.
|
||||||
|
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!type)
|
||||||
|
return;
|
||||||
|
// If we already have a global for this constant value, no need to do
|
||||||
|
// anything else.
|
||||||
|
auto it = globals.find(op.getValue());
|
||||||
|
if (it != globals.end())
|
||||||
|
return;
|
||||||
|
|
||||||
|
// Create a pretty name.
|
||||||
|
SmallString<64> buf;
|
||||||
|
llvm::raw_svector_ostream os(buf);
|
||||||
|
interleave(type.getShape(), os, "x");
|
||||||
|
os << "x" << type.getElementType();
|
||||||
|
|
||||||
|
auto global = globalBuilder.create<tcp::GlobalOp>(
|
||||||
|
op.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||||
|
op.getValue().cast<ElementsAttr>());
|
||||||
|
symbolTable.insert(global);
|
||||||
|
// The symbol table inserts at the end of the module, but globals are a bit
|
||||||
|
// nicer if they are at the beginning.
|
||||||
|
global.getOperation()->moveBefore(&module.front());
|
||||||
|
globals[op.getValue()] = global;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerConstantTensorsToMemrefs
|
||||||
|
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
|
||||||
|
void runOnOperation() {
|
||||||
|
auto module = getOperation();
|
||||||
|
GlobalCreator globals(module);
|
||||||
|
|
||||||
|
// With the global traversal factored into GlobalCreator, this could in
|
||||||
|
// principle be done with a pattern.
|
||||||
|
module.walk([&](ConstantOp op) {
|
||||||
|
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||||
|
if (!type)
|
||||||
|
return;
|
||||||
|
auto global = globals.getGlobalFor(op.getValue());
|
||||||
|
OpBuilder builder(op);
|
||||||
|
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
|
||||||
|
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
|
||||||
|
op.getLoc(), memrefType, global.getName());
|
||||||
|
Value tensor = builder.create<TensorLoadOp>(op.getLoc(), type, memref);
|
||||||
|
op.replaceAllUsesWith(tensor);
|
||||||
|
op.erase();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
std::unique_ptr<OperationPass<ModuleOp>>
|
||||||
|
mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass() {
|
||||||
|
return std::make_unique<LowerConstantTensorsToMemrefs>();
|
||||||
|
}
|
||||||
|
|
||||||
void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
|
void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
|
||||||
// Lower to hybrid tensor/memref.
|
// Lower to hybrid tensor/memref.
|
||||||
// The invariant of "hybrid tensor/memref" is that the core computation
|
// The invariant of "hybrid tensor/memref" is that the core computation
|
||||||
|
@ -305,6 +390,7 @@ void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
|
||||||
// allocated with alloc_shape ops.
|
// allocated with alloc_shape ops.
|
||||||
// Thus, shape.shape_of ops on the original tensors in the program can be
|
// Thus, shape.shape_of ops on the original tensors in the program can be
|
||||||
// resolved to the shapes in the alloc_memref calls.
|
// resolved to the shapes in the alloc_memref calls.
|
||||||
|
pm.addPass(createLowerConstantTensorsToMemrefsPass());
|
||||||
pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass());
|
pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass());
|
||||||
pm.addPass(createLowerBroadcastToToLoopsPass());
|
pm.addPass(createLowerBroadcastToToLoopsPass());
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,6 +22,28 @@ using namespace mlir::NPCOMP;
|
||||||
using mlir::LLVM::LLVMFuncOp;
|
using mlir::LLVM::LLVMFuncOp;
|
||||||
using mlir::LLVM::LLVMType;
|
using mlir::LLVM::LLVMType;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Utilities.
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// TODO: Move other descriptor types to here.
|
||||||
|
|
||||||
|
// Get the LLVMType for npcomprt::GlobalDescriptor.
|
||||||
|
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||||
|
return LLVMType::getStructTy(
|
||||||
|
// std::int32_t numExtents;
|
||||||
|
LLVMType::getIntNTy(llvmDialect, 32),
|
||||||
|
// std::int32_t *extents;
|
||||||
|
LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(),
|
||||||
|
// It is important that this struct member is a type-erased pointer
|
||||||
|
// so that this type is "context-free" and can be created in conversion
|
||||||
|
// patterns independently of the actual type of the data stored in the
|
||||||
|
// buffer.
|
||||||
|
//
|
||||||
|
// void *data;
|
||||||
|
LLVMType::getInt8PtrTy(llvmDialect));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Compiler runtime functions.
|
// Compiler runtime functions.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -72,6 +94,36 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class GetGlobalOpCompilerRuntimeLowering
|
||||||
|
: public OpConversionPattern<npcomprt::GetGlobalOp> {
|
||||||
|
public:
|
||||||
|
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||||
|
: OpConversionPattern<npcomprt::GetGlobalOp>(backingFunc.getContext()),
|
||||||
|
backingFunc(backingFunc) {}
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(npcomprt::GetGlobalOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto llvmDialect =
|
||||||
|
rewriter.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||||
|
// It would be nice if we could use the constructor here that takes just the
|
||||||
|
// global, but keeping track of the converted llvm.mlir.global op that gets
|
||||||
|
// created from the npcomprt.global while conversion is going on is a
|
||||||
|
// headache.
|
||||||
|
//
|
||||||
|
// Instead, we rely on the symbol name being the same and the result type
|
||||||
|
// always being the same.
|
||||||
|
auto globalAddr = rewriter.create<LLVM::AddressOfOp>(
|
||||||
|
op.getLoc(), getGlobalDescriptorTy(llvmDialect).getPointerTo(),
|
||||||
|
op.globalAttr());
|
||||||
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc,
|
||||||
|
ValueRange({globalAddr}));
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
LLVM::LLVMFuncOp backingFunc;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
// Create the LLVM runtime function backing the npcomprt op with name `name`
|
// Create the LLVM runtime function backing the npcomprt op with name `name`
|
||||||
// and requiring `type`.
|
// and requiring `type`.
|
||||||
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
||||||
|
@ -140,8 +192,164 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
||||||
"from_memref", funcTy, builder, module.getLoc());
|
"from_memref", funcTy, builder, module.getLoc());
|
||||||
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
|
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// Hardcoding f32 is fine here, since unranked memref descriptors have
|
||||||
|
// identical struct layout / ABI / contents regardless of the element type.
|
||||||
|
auto mlirFunctionType = builder.getFunctionType(
|
||||||
|
{getGlobalDescriptorTy(llvmDialect).getPointerTo()},
|
||||||
|
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
|
||||||
|
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
||||||
|
LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl(
|
||||||
|
"get_global", funcTy, builder, module.getLoc());
|
||||||
|
patterns.insert<GetGlobalOpCompilerRuntimeLowering>(backingFunc);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Lowering for npcomprt.global
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerNpcomprtGlobalOp : public OpConversionPattern<npcomprt::GlobalOp> {
|
||||||
|
public:
|
||||||
|
explicit LowerNpcomprtGlobalOp(LLVMTypeConverter &typeConverter)
|
||||||
|
: OpConversionPattern<npcomprt::GlobalOp>(&typeConverter.getContext()),
|
||||||
|
typeConverter(typeConverter) {}
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(npcomprt::GlobalOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto llvmDialect = typeConverter.getDialect();
|
||||||
|
auto globalDescriptorTy = getGlobalDescriptorTy(llvmDialect);
|
||||||
|
|
||||||
|
// Create the data buffer.
|
||||||
|
auto dataBuffer = createGlobalForDenseElementsAttr(
|
||||||
|
(Twine("__npcomprt_global_data_buffer_") + op.sym_name()).str(),
|
||||||
|
op.value().cast<DenseElementsAttr>(), op, rewriter);
|
||||||
|
|
||||||
|
// Create the extents buffer.
|
||||||
|
auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>(
|
||||||
|
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
|
||||||
|
[](int64_t i) -> int32_t { return i; })));
|
||||||
|
auto extentsBuffer = createGlobalForDenseElementsAttr(
|
||||||
|
(Twine("__npcomprt_global_extents_") + op.sym_name()).str(), extentsI32,
|
||||||
|
op, rewriter);
|
||||||
|
|
||||||
|
// Create the GlobalDescriptor.
|
||||||
|
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
|
||||||
|
op.getLoc(), globalDescriptorTy, /*isConstant=*/true,
|
||||||
|
LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute());
|
||||||
|
|
||||||
|
OpBuilder::InsertionGuard guard(rewriter);
|
||||||
|
rewriter.createBlock(&globalDescriptorGlobal.initializer());
|
||||||
|
|
||||||
|
// Create the body of the initializer.
|
||||||
|
Value globalDescriptor =
|
||||||
|
rewriter.create<LLVM::UndefOp>(op.getLoc(), globalDescriptorTy);
|
||||||
|
auto updateDescriptor = [&](Value value,
|
||||||
|
std::initializer_list<int32_t> position) {
|
||||||
|
globalDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||||
|
op.getLoc(), globalDescriptor, value,
|
||||||
|
/*position=*/rewriter.getI32ArrayAttr(position));
|
||||||
|
};
|
||||||
|
updateDescriptor(
|
||||||
|
rewriter.create<LLVM::ConstantOp>(
|
||||||
|
op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32),
|
||||||
|
rewriter.getI32IntegerAttr(
|
||||||
|
op.value().getType().cast<ShapedType>().getRank())),
|
||||||
|
{0});
|
||||||
|
|
||||||
|
// The global is actually an array, so we need to get a bare i32* pointer
|
||||||
|
// type. We could do this with GEP but it would be more verbose.
|
||||||
|
auto extentsBufferArrayAddress =
|
||||||
|
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), extentsBuffer);
|
||||||
|
auto extentsBufferAddress = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(),
|
||||||
|
extentsBufferArrayAddress);
|
||||||
|
updateDescriptor(extentsBufferAddress, {1});
|
||||||
|
|
||||||
|
auto dataBufferAddress =
|
||||||
|
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), dataBuffer);
|
||||||
|
auto typeErasedDataBufferAddress = rewriter.create<LLVM::BitcastOp>(
|
||||||
|
op.getLoc(), LLVMType::getInt8PtrTy(llvmDialect), dataBufferAddress);
|
||||||
|
updateDescriptor(typeErasedDataBufferAddress, {2});
|
||||||
|
rewriter.create<LLVM::ReturnOp>(op.getLoc(), globalDescriptor);
|
||||||
|
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// TODO: It feels like MLIR core should have better utilities for this.
|
||||||
|
LLVM::GlobalOp createGlobalForDenseElementsAttr(
|
||||||
|
StringRef symbolName, DenseElementsAttr elements, npcomprt::GlobalOp op,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto type = elements.getType().cast<ShapedType>();
|
||||||
|
|
||||||
|
// LLVM translation doesn't handle the case of zero-sized tensors, which can
|
||||||
|
// happen e.g. for the number of extents of a rank-0 (i.e. scalar).
|
||||||
|
//
|
||||||
|
// We fake-up a size-1 DenseElementsAttr to use for creating the global.
|
||||||
|
// That takes up binary space (one element instead of zero), but that seems
|
||||||
|
// fine.
|
||||||
|
//
|
||||||
|
// TODO: LLVM translation in MLIR core should handle this case better.
|
||||||
|
if (type.getNumElements() == 0) {
|
||||||
|
auto elementType = type.getElementType();
|
||||||
|
Attribute singleElement;
|
||||||
|
if (elementType.isIntOrIndex())
|
||||||
|
singleElement = rewriter.getIntegerAttr(elementType, 0);
|
||||||
|
else if (elementType.isa<FloatType>())
|
||||||
|
singleElement = rewriter.getFloatAttr(elementType, 0);
|
||||||
|
assert(singleElement &&
|
||||||
|
"could not fake up an element for a zero element tensor");
|
||||||
|
type = RankedTensorType::get({1}, elementType);
|
||||||
|
elements =
|
||||||
|
DenseElementsAttr::get(type, ArrayRef<Attribute>(singleElement));
|
||||||
|
}
|
||||||
|
|
||||||
|
auto llvmType = getLLVMTypeForShapedType(type, op, rewriter);
|
||||||
|
return rewriter.create<LLVM::GlobalOp>(
|
||||||
|
op.getLoc(), llvmType,
|
||||||
|
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
|
||||||
|
}
|
||||||
|
|
||||||
|
LLVMType getLLVMTypeForShapedType(ShapedType type, npcomprt::GlobalOp op,
|
||||||
|
ConversionPatternRewriter &rewriter) const {
|
||||||
|
auto llvmType =
|
||||||
|
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
|
||||||
|
|
||||||
|
// MLIR->LLVM lowering for globals requires non-scalar data types. So use a
|
||||||
|
// dummy size-1 array for the scalar case.
|
||||||
|
//
|
||||||
|
// TODO: LLVM translation in MLIR core should handle this case better.
|
||||||
|
if (type.getRank() == 0)
|
||||||
|
return LLVMType::getArrayTy(llvmType, 1);
|
||||||
|
|
||||||
|
if (!llvmType) {
|
||||||
|
rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||||
|
diag << "cannot convert element type " << type.getElementType()
|
||||||
|
<< " to an LLVM type";
|
||||||
|
});
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct an LLVM nested array type for the tensor initializer.
|
||||||
|
// tensor<f32> -> float
|
||||||
|
// tensor<10xf32> -> [10 x float]
|
||||||
|
// tensor<2x3xf32> -> [2 x [3 x float]]
|
||||||
|
assert(type.hasStaticShape());
|
||||||
|
auto shape = type.getShape();
|
||||||
|
while (!shape.empty()) {
|
||||||
|
llvmType = LLVMType::getArrayTy(llvmType, shape.back());
|
||||||
|
shape = shape.drop_back();
|
||||||
|
}
|
||||||
|
return llvmType;
|
||||||
|
}
|
||||||
|
LLVMTypeConverter &typeConverter;
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Lowering for module metadata
|
// Lowering for module metadata
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -443,6 +651,7 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
||||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||||
patterns.insert<LowerModuleMetadata>(context);
|
patterns.insert<LowerModuleMetadata>(context);
|
||||||
|
patterns.insert<LowerNpcomprtGlobalOp>(converter);
|
||||||
|
|
||||||
if (failed(applyFullConversion(module, target, patterns))) {
|
if (failed(applyFullConversion(module, target, patterns))) {
|
||||||
return signalPassFailure();
|
return signalPassFailure();
|
||||||
|
|
|
@ -139,6 +139,39 @@ public:
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tcp::GlobalOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
|
||||||
|
op.value());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
class LowerGetGlobalMemrefOp
|
||||||
|
: public OpConversionPattern<tcp::GetGlobalMemrefOp> {
|
||||||
|
public:
|
||||||
|
using OpConversionPattern::OpConversionPattern;
|
||||||
|
LogicalResult
|
||||||
|
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||||
|
ConversionPatternRewriter &rewriter) const override {
|
||||||
|
auto abiMemrefType = UnrankedMemRefType::get(
|
||||||
|
op.getType().cast<ShapedType>().getElementType(), /*memorySpace=*/0);
|
||||||
|
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
|
||||||
|
op.getLoc(), abiMemrefType, op.global());
|
||||||
|
// Cast back to the original type.
|
||||||
|
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // namespace
|
||||||
|
|
||||||
static LogicalResult doDialectConversion(ModuleOp module) {
|
static LogicalResult doDialectConversion(ModuleOp module) {
|
||||||
auto *context = module.getContext();
|
auto *context = module.getContext();
|
||||||
|
|
||||||
|
@ -172,6 +205,14 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
||||||
target.addLegalOp<shape::FromExtentsOp>();
|
target.addLegalOp<shape::FromExtentsOp>();
|
||||||
target.addLegalOp<npcomprt::GetExtentOp>();
|
target.addLegalOp<npcomprt::GetExtentOp>();
|
||||||
|
|
||||||
|
patterns.insert<LowerGlobalOp>(context);
|
||||||
|
target.addIllegalOp<tcp::GlobalOp>();
|
||||||
|
target.addLegalOp<npcomprt::GlobalOp>();
|
||||||
|
|
||||||
|
patterns.insert<LowerGetGlobalMemrefOp>(context);
|
||||||
|
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
|
||||||
|
target.addLegalOp<npcomprt::GetGlobalOp>();
|
||||||
|
|
||||||
return applyPartialConversion(module, target, patterns);
|
return applyPartialConversion(module, target, patterns);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file contains data structures (which we typically call "descriptors")
|
||||||
|
// that are emitted by the compiler and must be kept in sync with the compiler
|
||||||
|
// code that creates them in LowerToLLVM.cpp.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#ifndef NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||||
|
#define NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace npcomprt {
|
||||||
|
|
||||||
|
// All arguments are packed into this type-erased form for being invoked. See
|
||||||
|
// LowerToLLVM.cpp for more details.
|
||||||
|
typedef void ABIFunc(void **, void **);
|
||||||
|
|
||||||
|
struct FuncDescriptor {
|
||||||
|
// The length of the function name.
|
||||||
|
std::int32_t nameLen;
|
||||||
|
// The name of the function, to allow lookup.
|
||||||
|
const char *name;
|
||||||
|
// This is a raw function pointer to the function's entry point as
|
||||||
|
// emitted by the compiler.
|
||||||
|
ABIFunc *functionPtr;
|
||||||
|
// The number of inputs to the function.
|
||||||
|
std::int32_t numInputs;
|
||||||
|
// The number of outputs of the function.
|
||||||
|
std::int32_t numOutputs;
|
||||||
|
// TODO: Add arg/result descriptors and other metadata.
|
||||||
|
// With those descriptors we can do type and shape checking for each
|
||||||
|
// argument.
|
||||||
|
};
|
||||||
|
|
||||||
|
// The top-level entry point of the module metadata emitted by the
|
||||||
|
// compiler. Unlike all the other descriptors here, external code does handle
|
||||||
|
// this type (albeit through an opaque pointer).
|
||||||
|
struct ModuleDescriptor {
|
||||||
|
std::int32_t numFuncDescriptors;
|
||||||
|
FuncDescriptor *functionDescriptors;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Static data representing a global variable (together with its shape).
|
||||||
|
struct GlobalDescriptor {
|
||||||
|
std::int32_t numExtents;
|
||||||
|
std::int32_t *extents;
|
||||||
|
void *data;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace npcomprt
|
||||||
|
|
||||||
|
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
|
@ -15,6 +15,7 @@
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
|
|
||||||
|
#include "CompilerDataStructures.h"
|
||||||
#include "npcomp/runtime/UserAPI.h"
|
#include "npcomp/runtime/UserAPI.h"
|
||||||
|
|
||||||
using namespace npcomprt;
|
using namespace npcomprt;
|
||||||
|
@ -118,4 +119,12 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank,
|
||||||
extents32Buf[i] = extents64[i];
|
extents32Buf[i] = extents64[i];
|
||||||
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
|
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
|
||||||
elementType, data);
|
elementType, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" UnrankedMemref
|
||||||
|
__npcomp_compiler_rt_get_global(GlobalDescriptor *global) {
|
||||||
|
auto *descriptor = MemrefDescriptor::create(
|
||||||
|
ArrayRef<std::int32_t>(global->extents, global->numExtents),
|
||||||
|
global->data);
|
||||||
|
return UnrankedMemref{global->numExtents, descriptor};
|
||||||
}
|
}
|
|
@ -13,6 +13,8 @@
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
|
#include "CompilerDataStructures.h"
|
||||||
|
|
||||||
using namespace npcomprt;
|
using namespace npcomprt;
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -63,39 +65,7 @@ std::int32_t Tensor::getDataByteSize() const {
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Module metadata descriptors.
|
// Module metadata descriptors.
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// These descriptors are never created by runtime code. They are always
|
|
||||||
// embedded by the compiler as static data inside the module.
|
|
||||||
//
|
|
||||||
// Their definitions need to be kept in sync with the compiler code in
|
|
||||||
// LowerToLLVM.cpp
|
|
||||||
|
|
||||||
// All arguments are packed into this type-erased form for being invoked. See
|
|
||||||
// LowerToLLVM.cpp for more details.
|
|
||||||
typedef void ABIFunc(void **, void **);
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
struct FuncDescriptor {
|
|
||||||
// The length of the function name.
|
|
||||||
std::int32_t nameLen;
|
|
||||||
// The name of the function, to allow lookup.
|
|
||||||
const char *name;
|
|
||||||
// This is a raw function pointer to the function's entry point as
|
|
||||||
// emitted by the compiler.
|
|
||||||
ABIFunc *functionPtr;
|
|
||||||
std::int32_t numInputs;
|
|
||||||
std::int32_t numOutputs;
|
|
||||||
// TODO: Add arg/result descriptors and other metadata.
|
|
||||||
// With those descriptors.
|
|
||||||
};
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
// The top-level entry point of the module metadata emitted by the
|
|
||||||
// compiler.
|
|
||||||
struct npcomprt::ModuleDescriptor {
|
|
||||||
std::int32_t numFuncDescriptors;
|
|
||||||
// TODO: Update compiler code to emit this as a separate global.
|
|
||||||
FuncDescriptor *functionDescriptors;
|
|
||||||
};
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Module operations.
|
// Module operations.
|
||||||
|
|
|
@ -21,3 +21,23 @@ npcomprt.module_metadata {
|
||||||
npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
||||||
}
|
}
|
||||||
func @f() { return }
|
func @f() { return }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
npcomprt.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{must reference a valid npcomprt.global}}
|
||||||
|
npcomprt.get_global @nonexistent_symbol : memref<*xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
npcomprt.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{inconsistent with element type of global}}
|
||||||
|
npcomprt.get_global @g : memref<*xi8>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
|
@ -12,3 +12,9 @@ func @f(%arg0: !npcomprt.tensor) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: npcomprt.global @g dense<0.0{{.*}}> : tensor<10xf32>
|
||||||
|
npcomprt.global @g dense<0.0> : tensor<10xf32>
|
||||||
|
func @uses_global() {
|
||||||
|
npcomprt.get_global @g : memref<*xf32>
|
||||||
|
return
|
||||||
|
}
|
|
@ -0,0 +1,31 @@
|
||||||
|
// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{must reference a valid symbol}}
|
||||||
|
tcp.get_global_memref @nonexistent_symbol : memref<3xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{inconsistent with shape of global}}
|
||||||
|
tcp.get_global_memref @g : memref<3xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||||
|
|
||||||
|
func @f() {
|
||||||
|
// expected-error @+1 {{inconsistent with element type of global}}
|
||||||
|
tcp.get_global_memref @g : memref<2xi8>
|
||||||
|
return
|
||||||
|
}
|
|
@ -1,7 +1,11 @@
|
||||||
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
|
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
||||||
|
tcp.global @foo dense<0.0> : tensor<10xf32>
|
||||||
|
|
||||||
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
|
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
|
||||||
// CHECK: tcp.add
|
// CHECK: tcp.add
|
||||||
%0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
%0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||||
|
%1 = tcp.get_global_memref @foo : memref<10xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memrefs <%s | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK-LABEL: module {
|
||||||
|
// We check the debug name too since we put some effort into making that readable.
|
||||||
|
// The name isn't load-bearing though.
|
||||||
|
// CHECK: tcp.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
|
||||||
|
// CHECK: func @basic
|
||||||
|
func @basic() -> tensor<3x4xf32> {
|
||||||
|
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
||||||
|
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
|
||||||
|
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||||
|
// CHECK: return %[[TENSOR]]
|
||||||
|
return %0 : tensor<3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: module {
|
||||||
|
|
||||||
|
// Only one global is created.
|
||||||
|
// CHECK: tcp.global
|
||||||
|
// CHECK-NOT: tcp.global
|
||||||
|
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||||
|
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||||
|
%1 = constant dense<7.0> : tensor<3x4xf32>
|
||||||
|
// CHECK: return %[[TENSOR]]
|
||||||
|
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: module {
|
||||||
|
|
||||||
|
// Two globals are created.
|
||||||
|
// CHECK: tcp.global
|
||||||
|
// CHECK: tcp.global
|
||||||
|
// CHECK-NOT: tcp.global
|
||||||
|
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||||
|
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||||
|
%1 = constant dense<8.0> : tensor<3x4xf32>
|
||||||
|
// CHECK: return %[[TENSOR]]
|
||||||
|
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: }
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: module {
|
||||||
|
// We don't convert non-tensor globals.
|
||||||
|
// CHECK-NOT: tcp.global
|
||||||
|
func @non_tensor() {
|
||||||
|
%0 = constant 7 : i32
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK: }
|
|
@ -0,0 +1,32 @@
|
||||||
|
// RUN: npcomp-opt -e2e-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
||||||
|
|
||||||
|
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm<"[3 x float]">
|
||||||
|
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm<"[1 x i32]">
|
||||||
|
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm<"{ i32, i32*, i8* }"> {
|
||||||
|
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, i32*, i8* }">
|
||||||
|
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||||
|
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, i32*, i8* }">
|
||||||
|
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomprt_global_extents_g : !llvm<"[1 x i32]*">
|
||||||
|
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x i32]*"> to !llvm<"i32*">
|
||||||
|
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, i32*, i8* }">
|
||||||
|
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__npcomprt_global_data_buffer_g : !llvm<"[3 x float]*">
|
||||||
|
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm<"[3 x float]*"> to !llvm<"i8*">
|
||||||
|
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm<"{ i32, i32*, i8* }">
|
||||||
|
// CHECK: llvm.return %[[VAL_8]] : !llvm<"{ i32, i32*, i8* }">
|
||||||
|
// CHECK: }
|
||||||
|
// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm<"{ i64, i8* }"> {
|
||||||
|
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm<"{ i32, i32*, i8* }*">
|
||||||
|
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm<"{ i32, i32*, i8* }*">) -> !llvm<"{ i64, i8* }">
|
||||||
|
npcomprt.global @g dense<7.000000e+00> : tensor<3xf32>
|
||||||
|
func @calls_get_global() -> memref<*xf32> {
|
||||||
|
%0 = npcomprt.get_global @g : memref<*xf32>
|
||||||
|
return %0 : memref<*xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
|
||||||
|
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm<"[1 x float]">
|
||||||
|
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm<"[1 x i32]">
|
||||||
|
npcomprt.global @g dense<7.0> : tensor<f32>
|
||||||
|
|
|
@ -37,6 +37,23 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
|
||||||
|
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
|
||||||
|
tcp.global @g dense<7.0> : tensor<10xf32>
|
||||||
|
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
|
||||||
|
func @gets_global() -> tensor<10xf32> {
|
||||||
|
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
|
||||||
|
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
||||||
|
// CHECK: %[[RETMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
||||||
|
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[RETMEMREF]] : memref<*xf32>
|
||||||
|
// CHECK: return %[[RET]] : !npcomprt.tensor
|
||||||
|
%0 = tcp.get_global_memref @g : memref<10xf32>
|
||||||
|
%1 = tensor_load %0 : memref<10xf32>
|
||||||
|
return %1 : tensor<10xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// expected-error @+1 {{func not expressible with npcomprt ABI}}
|
// expected-error @+1 {{func not expressible with npcomprt ABI}}
|
||||||
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
||||||
return
|
return
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
// RUN: npcomp-run-mlir -input %s \
|
||||||
|
// RUN: -invoke constant_add_scalar \
|
||||||
|
// RUN: -arg-value="dense<3.0> : tensor<f32>" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: output #0: dense<4.000000e+00> : tensor<f32>
|
||||||
|
func @constant_add_scalar(%arg0: tensor<f32>) -> tensor<f32> {
|
||||||
|
%0 = constant dense<1.0> : tensor<f32>
|
||||||
|
%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||||
|
return %1 : tensor<f32>
|
||||||
|
}
|
|
@ -0,0 +1,12 @@
|
||||||
|
// RUN: npcomp-run-mlir -input %s \
|
||||||
|
// RUN: -invoke constant_add \
|
||||||
|
// RUN: -arg-value="dense<[3.0, 5.0]> : tensor<2xf32>" \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: output #0: dense<[4.000000e+00, 7.000000e+00]> : tensor<2xf32>
|
||||||
|
func @constant_add(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||||
|
%0 = constant dense<[1.0, 2.0]> : tensor<2xf32>
|
||||||
|
%1 = "tcf.add"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||||
|
return %1 : tensor<2xf32>
|
||||||
|
}
|
|
@ -0,0 +1,10 @@
|
||||||
|
// RUN: npcomp-run-mlir -input %s \
|
||||||
|
// RUN: -invoke constant \
|
||||||
|
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||||
|
// RUN: | FileCheck %s
|
||||||
|
|
||||||
|
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
|
||||||
|
func @constant() -> tensor<f32> {
|
||||||
|
%0 = constant dense<1.0> : tensor<f32>
|
||||||
|
return %0 : tensor<f32>
|
||||||
|
}
|
|
@ -47,3 +47,5 @@ npctall() {
|
||||||
# https://superuser.com/q/253068
|
# https://superuser.com/q/253068
|
||||||
export FIGNORE=$FIGNORE:nstall-mlir
|
export FIGNORE=$FIGNORE:nstall-mlir
|
||||||
|
|
||||||
|
export PYTHONPATH="$(realpath ${build_dir}/python):$(realpath ${build_dir}/python_native):$(realpath ${build_dir}/iree/bindings/python)"
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue