npcomprt: add support for constants

- create tcp.global + tcp.get_global_memref
- create npcomprt.global + npcomprt.get_global
- LLVM lowering for new npcomprt ops
- Runtime:
 - GlobalDescriptor struct emitted by LLVM lowering
 - implement __npcomp_compiler_rt_get_global

Also,
- cleanly isolate all runtime data structure definitions shared by the
compiler and runtime into lib/runtime/CompilerDataStructures.h
pull/1/head
Sean Silva 2020-07-10 17:31:24 -07:00
parent 2e40ce05ad
commit e228aa4b11
27 changed files with 784 additions and 34 deletions

View File

@ -27,7 +27,7 @@ def Npcomprt_Tensor
: DialectType< : DialectType<
Npcomprt_Dialect, Npcomprt_Dialect,
CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">, CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">,
"buffer view">, "npcomprt.tensor">,
BuildableType< BuildableType<
"$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> { "$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> {
let typeDescription = [{The runtime type that represents a buffer.}]; let typeDescription = [{The runtime type that represents a buffer.}];

View File

@ -12,6 +12,7 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
namespace mlir { namespace mlir {
namespace NPCOMP { namespace NPCOMP {

View File

@ -10,6 +10,7 @@
#define NPCOMPRT_OPS #define NPCOMPRT_OPS
include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td" include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td"
include "mlir/IR/SymbolInterfaces.td"
class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []> class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Npcomprt_Dialect, mnemonic, traits> { : Op<Npcomprt_Dialect, mnemonic, traits> {
@ -57,6 +58,44 @@ def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
let assemblyFormat = "$pred attr-dict"; let assemblyFormat = "$pred attr-dict";
} }
def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
let summary = "Represents a global variable";
let description = [{
Represents a global variable.
Currently, only constant tensors are supported, and they are not
considered to be exported.
}];
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
let results = (outs);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
let summary = "Obtain a rank-erased memref pointing at the given global";
let description = [{
Obtain a rank-erased memref pointing at the given global.
TODO: As we define the runtime layer better, we should have fewer
entry points that return memrefs, or at least have a clearer separation
between the "memref world" and the "npcomprt world".
Something like forming IREE dispatch regions seems to be the missing thing:
- Everything inside the dispatch regions gets things marshaled from the
runtime (flow/hal/npcomprt) layer to/from memrefs in a clear way.
- Everything outside the dispatch regions purely uses the runtime
(flow/hal/npcomprt) data structures.
Globals should be one of the things that are purely runtime data structures,
rather than using memrefs. For now, using memrefs is simpler though.
}];
let arguments = (ins FlatSymbolRefAttr:$global);
let results = (outs AnyUnrankedMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
// TODO: verify exists and shape is compatible
}
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [ def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp"> SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
]> { ]> {

View File

@ -13,6 +13,7 @@
#include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir { namespace mlir {

View File

@ -13,6 +13,7 @@ include "npcomp/Dialect/TCP/IR/TCPBase.td"
include "mlir/Dialect/Shape/IR/ShapeBase.td" include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/IR/SymbolInterfaces.td"
class TCP_Op<string mnemonic, list<OpTrait> traits = []> class TCP_Op<string mnemonic, list<OpTrait> traits = []>
: Op<TCP_Dialect, mnemonic, traits> { : Op<TCP_Dialect, mnemonic, traits> {
@ -58,6 +59,32 @@ Allocates a memref of the given shape.
let assemblyFormat = "$shape attr-dict `:` type($memref)"; let assemblyFormat = "$shape attr-dict `:` type($memref)";
} }
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
let summary = "Represents a global variable";
let description = [{
Represents a global variable.
Currently, only constant tensors are supported, and they are not
considered to be exported.
}];
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
let results = (outs);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
let summary = "Obtain a memref pointing at the given global";
let description = [{
Obtain a memref pointing at the given global.
}];
let arguments = (ins FlatSymbolRefAttr:$global);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
}
// TODO: Change to a more principled error handling mechanism. // TODO: Change to a more principled error handling mechanism.
// This op probably doesn't need to exist eventually. // This op probably doesn't need to exist eventually.
// This op is also not correctly modeled right now, since it itself doesn't // This op is also not correctly modeled right now, since it itself doesn't

View File

@ -25,6 +25,9 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
std::unique_ptr<OperationPass<FuncOp>> std::unique_ptr<OperationPass<FuncOp>>
createLowerLinalgOnTensorToLinalgOnMemrefPass(); createLowerLinalgOnTensorToLinalgOnMemrefPass();
std::unique_ptr<OperationPass<ModuleOp>>
createLowerConstantTensorsToMemrefsPass();
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass(); std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass(); std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();

View File

@ -23,6 +23,15 @@ def LowerBroadcastToToLoops :
let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()"; let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()";
} }
def LowerConstantTensorsToMemrefs :
Pass<"lower-constant-tensors-to-memrefs", "ModuleOp"> {
let summary = "Lower std.constant of tensor type to hybrid tensor/memref.";
let description = [{
This has to be a module pass since it involves creating tcp.global ops.
}];
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass()";
}
def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> { def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
let summary = "Resolve shape.shape_of ops to other shapes."; let summary = "Resolve shape.shape_of ops to other shapes.";
let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()"; let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";

View File

@ -10,11 +10,53 @@
#include "mlir/IR/Builders.h" #include "mlir/IR/Builders.h"
#include "mlir/IR/Function.h" #include "mlir/IR/Function.h"
#include "mlir/IR/SymbolTable.h" #include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h" #include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP::npcomprt; using namespace mlir::NPCOMP::npcomprt;
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
p << "npcomprt.global ";
p.printSymbolName(op.sym_name());
p << ' ';
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
/*elidedAttrs=*/{"sym_name", "value"});
p.printAttribute(op.valueAttr());
}
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
Attribute valueAttr;
if (parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
if (!global)
return op.emitError() << "must reference a valid npcomprt.global";
auto globalType = global.value().getType();
auto resultType = op.getType().cast<ShapedType>();
if (globalType.getElementType() != resultType.getElementType())
return op.emitError() << "inconsistent with element type of global";
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ModuleMetadataOp // ModuleMetadataOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -8,11 +8,56 @@
#include "npcomp/Dialect/TCP/IR/TCPOps.h" #include "npcomp/Dialect/TCP/IR/TCPOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/TypeUtilities.h"
using namespace mlir; using namespace mlir;
using namespace mlir::NPCOMP; using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::tcp; using namespace mlir::NPCOMP::tcp;
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
p << "tcp.global ";
p.printSymbolName(op.sym_name());
p << ' ';
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
/*elidedAttrs=*/{"sym_name", "value"});
p.printAttribute(op.valueAttr());
}
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
Attribute valueAttr;
if (parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalMemrefOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
if (!global)
return op.emitError() << "must reference a valid symbol";
auto globalType = global.value().getType();
auto resultType = op.getType().cast<ShapedType>();
if (failed(
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
return op.emitError() << "inconsistent with shape of global";
if (globalType.getElementType() != resultType.getElementType())
return op.emitError() << "inconsistent with element type of global";
return success();
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// ShapeObserveErrorOp // ShapeObserveErrorOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -359,7 +359,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
pm.addPass(createResolveTensorLoadStoreOpsPass()); pm.addPass(createResolveTensorLoadStoreOpsPass());
// At this point, the IR is in a form where there are no tensor ops // At this point, the IR is in a form where there are no tensor ops
// (except tensor_store's of arguments and tensor_load's of returns). // (except tensor_store's of arguments, tensor_load's of returns, and
// constants).
// //
// This is a reasonable representation for doing buffer assignment. // This is a reasonable representation for doing buffer assignment.
// TODO: Do buffer assignment here. // TODO: Do buffer assignment here.

View File

@ -296,6 +296,91 @@ mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass() {
return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>(); return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>();
} }
//===----------------------------------------------------------------------===//
// LowerConstantTensorsToMemrefs
//===----------------------------------------------------------------------===//
namespace {
// This class creates global ops for all tensor-valued constants in the program.
// It creates them with pretty names and makes sure that duplicate globals
// aren't created.
class GlobalCreator {
public:
explicit GlobalCreator(ModuleOp module);
tcp::GlobalOp getGlobalFor(Attribute attr) {
assert(globals.find(attr) != globals.end() && "unknown constant attr");
return globals[attr];
}
private:
DenseMap<Attribute, tcp::GlobalOp> globals;
};
GlobalCreator::GlobalCreator(ModuleOp module) {
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(module.getContext());
SymbolTable symbolTable(module);
module.walk([&](ConstantOp op) {
// We only want tensor constants for now.
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type)
return;
// If we already have a global for this constant value, no need to do
// anything else.
auto it = globals.find(op.getValue());
if (it != globals.end())
return;
// Create a pretty name.
SmallString<64> buf;
llvm::raw_svector_ostream os(buf);
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
auto global = globalBuilder.create<tcp::GlobalOp>(
op.getLoc(), (Twine("__constant_") + os.str()).str(),
op.getValue().cast<ElementsAttr>());
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
global.getOperation()->moveBefore(&module.front());
globals[op.getValue()] = global;
});
}
} // namespace
namespace {
class LowerConstantTensorsToMemrefs
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
void runOnOperation() {
auto module = getOperation();
GlobalCreator globals(module);
// With the global traversal factored into GlobalCreator, this could in
// principle be done with a pattern.
module.walk([&](ConstantOp op) {
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type)
return;
auto global = globals.getGlobalFor(op.getValue());
OpBuilder builder(op);
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
op.getLoc(), memrefType, global.getName());
Value tensor = builder.create<TensorLoadOp>(op.getLoc(), type, memref);
op.replaceAllUsesWith(tensor);
op.erase();
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass() {
return std::make_unique<LowerConstantTensorsToMemrefs>();
}
void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) { void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
// Lower to hybrid tensor/memref. // Lower to hybrid tensor/memref.
// The invariant of "hybrid tensor/memref" is that the core computation // The invariant of "hybrid tensor/memref" is that the core computation
@ -305,6 +390,7 @@ void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
// allocated with alloc_shape ops. // allocated with alloc_shape ops.
// Thus, shape.shape_of ops on the original tensors in the program can be // Thus, shape.shape_of ops on the original tensors in the program can be
// resolved to the shapes in the alloc_memref calls. // resolved to the shapes in the alloc_memref calls.
pm.addPass(createLowerConstantTensorsToMemrefsPass());
pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass()); pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass());
pm.addPass(createLowerBroadcastToToLoopsPass()); pm.addPass(createLowerBroadcastToToLoopsPass());
} }

View File

@ -22,6 +22,28 @@ using namespace mlir::NPCOMP;
using mlir::LLVM::LLVMFuncOp; using mlir::LLVM::LLVMFuncOp;
using mlir::LLVM::LLVMType; using mlir::LLVM::LLVMType;
//===----------------------------------------------------------------------===//
// Utilities.
//===----------------------------------------------------------------------===//
// TODO: Move other descriptor types to here.
// Get the LLVMType for npcomprt::GlobalDescriptor.
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
return LLVMType::getStructTy(
// std::int32_t numExtents;
LLVMType::getIntNTy(llvmDialect, 32),
// std::int32_t *extents;
LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(),
// It is important that this struct member is a type-erased pointer
// so that this type is "context-free" and can be created in conversion
// patterns independently of the actual type of the data stored in the
// buffer.
//
// void *data;
LLVMType::getInt8PtrTy(llvmDialect));
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Compiler runtime functions. // Compiler runtime functions.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -72,6 +94,36 @@ public:
}; };
} // namespace } // namespace
namespace {
class GetGlobalOpCompilerRuntimeLowering
: public OpConversionPattern<npcomprt::GetGlobalOp> {
public:
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
: OpConversionPattern<npcomprt::GetGlobalOp>(backingFunc.getContext()),
backingFunc(backingFunc) {}
LogicalResult
matchAndRewrite(npcomprt::GetGlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto llvmDialect =
rewriter.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
// It would be nice if we could use the constructor here that takes just the
// global, but keeping track of the converted llvm.mlir.global op that gets
// created from the npcomprt.global while conversion is going on is a
// headache.
//
// Instead, we rely on the symbol name being the same and the result type
// always being the same.
auto globalAddr = rewriter.create<LLVM::AddressOfOp>(
op.getLoc(), getGlobalDescriptorTy(llvmDialect).getPointerTo(),
op.globalAttr());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc,
ValueRange({globalAddr}));
return success();
}
LLVM::LLVMFuncOp backingFunc;
};
} // namespace
// Create the LLVM runtime function backing the npcomprt op with name `name` // Create the LLVM runtime function backing the npcomprt op with name `name`
// and requiring `type`. // and requiring `type`.
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type, static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
@ -140,8 +192,164 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
"from_memref", funcTy, builder, module.getLoc()); "from_memref", funcTy, builder, module.getLoc());
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc); patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
} }
{
// Hardcoding f32 is fine here, since unranked memref descriptors have
// identical struct layout / ABI / contents regardless of the element type.
auto mlirFunctionType = builder.getFunctionType(
{getGlobalDescriptorTy(llvmDialect).getPointerTo()},
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
LLVMType funcTy = convertFunctionType(mlirFunctionType);
LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl(
"get_global", funcTy, builder, module.getLoc());
patterns.insert<GetGlobalOpCompilerRuntimeLowering>(backingFunc);
}
} }
//===----------------------------------------------------------------------===//
// Lowering for npcomprt.global
//===----------------------------------------------------------------------===//
namespace {
class LowerNpcomprtGlobalOp : public OpConversionPattern<npcomprt::GlobalOp> {
public:
explicit LowerNpcomprtGlobalOp(LLVMTypeConverter &typeConverter)
: OpConversionPattern<npcomprt::GlobalOp>(&typeConverter.getContext()),
typeConverter(typeConverter) {}
LogicalResult
matchAndRewrite(npcomprt::GlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto llvmDialect = typeConverter.getDialect();
auto globalDescriptorTy = getGlobalDescriptorTy(llvmDialect);
// Create the data buffer.
auto dataBuffer = createGlobalForDenseElementsAttr(
(Twine("__npcomprt_global_data_buffer_") + op.sym_name()).str(),
op.value().cast<DenseElementsAttr>(), op, rewriter);
// Create the extents buffer.
auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>(
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
[](int64_t i) -> int32_t { return i; })));
auto extentsBuffer = createGlobalForDenseElementsAttr(
(Twine("__npcomprt_global_extents_") + op.sym_name()).str(), extentsI32,
op, rewriter);
// Create the GlobalDescriptor.
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
op.getLoc(), globalDescriptorTy, /*isConstant=*/true,
LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&globalDescriptorGlobal.initializer());
// Create the body of the initializer.
Value globalDescriptor =
rewriter.create<LLVM::UndefOp>(op.getLoc(), globalDescriptorTy);
auto updateDescriptor = [&](Value value,
std::initializer_list<int32_t> position) {
globalDescriptor = rewriter.create<LLVM::InsertValueOp>(
op.getLoc(), globalDescriptor, value,
/*position=*/rewriter.getI32ArrayAttr(position));
};
updateDescriptor(
rewriter.create<LLVM::ConstantOp>(
op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32),
rewriter.getI32IntegerAttr(
op.value().getType().cast<ShapedType>().getRank())),
{0});
// The global is actually an array, so we need to get a bare i32* pointer
// type. We could do this with GEP but it would be more verbose.
auto extentsBufferArrayAddress =
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), extentsBuffer);
auto extentsBufferAddress = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(),
extentsBufferArrayAddress);
updateDescriptor(extentsBufferAddress, {1});
auto dataBufferAddress =
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), dataBuffer);
auto typeErasedDataBufferAddress = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), LLVMType::getInt8PtrTy(llvmDialect), dataBufferAddress);
updateDescriptor(typeErasedDataBufferAddress, {2});
rewriter.create<LLVM::ReturnOp>(op.getLoc(), globalDescriptor);
rewriter.eraseOp(op);
return success();
}
private:
// TODO: It feels like MLIR core should have better utilities for this.
LLVM::GlobalOp createGlobalForDenseElementsAttr(
StringRef symbolName, DenseElementsAttr elements, npcomprt::GlobalOp op,
ConversionPatternRewriter &rewriter) const {
auto type = elements.getType().cast<ShapedType>();
// LLVM translation doesn't handle the case of zero-sized tensors, which can
// happen e.g. for the number of extents of a rank-0 (i.e. scalar).
//
// We fake-up a size-1 DenseElementsAttr to use for creating the global.
// That takes up binary space (one element instead of zero), but that seems
// fine.
//
// TODO: LLVM translation in MLIR core should handle this case better.
if (type.getNumElements() == 0) {
auto elementType = type.getElementType();
Attribute singleElement;
if (elementType.isIntOrIndex())
singleElement = rewriter.getIntegerAttr(elementType, 0);
else if (elementType.isa<FloatType>())
singleElement = rewriter.getFloatAttr(elementType, 0);
assert(singleElement &&
"could not fake up an element for a zero element tensor");
type = RankedTensorType::get({1}, elementType);
elements =
DenseElementsAttr::get(type, ArrayRef<Attribute>(singleElement));
}
auto llvmType = getLLVMTypeForShapedType(type, op, rewriter);
return rewriter.create<LLVM::GlobalOp>(
op.getLoc(), llvmType,
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
}
LLVMType getLLVMTypeForShapedType(ShapedType type, npcomprt::GlobalOp op,
ConversionPatternRewriter &rewriter) const {
auto llvmType =
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
// MLIR->LLVM lowering for globals requires non-scalar data types. So use a
// dummy size-1 array for the scalar case.
//
// TODO: LLVM translation in MLIR core should handle this case better.
if (type.getRank() == 0)
return LLVMType::getArrayTy(llvmType, 1);
if (!llvmType) {
rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "cannot convert element type " << type.getElementType()
<< " to an LLVM type";
});
return nullptr;
}
// Construct an LLVM nested array type for the tensor initializer.
// tensor<f32> -> float
// tensor<10xf32> -> [10 x float]
// tensor<2x3xf32> -> [2 x [3 x float]]
assert(type.hasStaticShape());
auto shape = type.getShape();
while (!shape.empty()) {
llvmType = LLVMType::getArrayTy(llvmType, shape.back());
shape = shape.drop_back();
}
return llvmType;
}
LLVMTypeConverter &typeConverter;
};
} // namespace
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Lowering for module metadata // Lowering for module metadata
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -443,6 +651,7 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
target.addLegalOp<ModuleOp, ModuleTerminatorOp>(); target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
populateStdToLLVMConversionPatterns(converter, patterns); populateStdToLLVMConversionPatterns(converter, patterns);
patterns.insert<LowerModuleMetadata>(context); patterns.insert<LowerModuleMetadata>(context);
patterns.insert<LowerNpcomprtGlobalOp>(converter);
if (failed(applyFullConversion(module, target, patterns))) { if (failed(applyFullConversion(module, target, patterns))) {
return signalPassFailure(); return signalPassFailure();

View File

@ -139,6 +139,39 @@ public:
}; };
} // namespace } // namespace
namespace {
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::GlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
op.value());
return success();
}
};
} // namespace
namespace {
class LowerGetGlobalMemrefOp
: public OpConversionPattern<tcp::GetGlobalMemrefOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto abiMemrefType = UnrankedMemRefType::get(
op.getType().cast<ShapedType>().getElementType(), /*memorySpace=*/0);
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
op.getLoc(), abiMemrefType, op.global());
// Cast back to the original type.
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
return success();
}
};
} // namespace
static LogicalResult doDialectConversion(ModuleOp module) { static LogicalResult doDialectConversion(ModuleOp module) {
auto *context = module.getContext(); auto *context = module.getContext();
@ -172,6 +205,14 @@ static LogicalResult doDialectConversion(ModuleOp module) {
target.addLegalOp<shape::FromExtentsOp>(); target.addLegalOp<shape::FromExtentsOp>();
target.addLegalOp<npcomprt::GetExtentOp>(); target.addLegalOp<npcomprt::GetExtentOp>();
patterns.insert<LowerGlobalOp>(context);
target.addIllegalOp<tcp::GlobalOp>();
target.addLegalOp<npcomprt::GlobalOp>();
patterns.insert<LowerGetGlobalMemrefOp>(context);
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
target.addLegalOp<npcomprt::GetGlobalOp>();
return applyPartialConversion(module, target, patterns); return applyPartialConversion(module, target, patterns);
} }

View File

@ -0,0 +1,60 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains data structures (which we typically call "descriptors")
// that are emitted by the compiler and must be kept in sync with the compiler
// code that creates them in LowerToLLVM.cpp.
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
#define NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
#include <cstdint>
namespace npcomprt {
// All arguments are packed into this type-erased form for being invoked. See
// LowerToLLVM.cpp for more details.
typedef void ABIFunc(void **, void **);
struct FuncDescriptor {
// The length of the function name.
std::int32_t nameLen;
// The name of the function, to allow lookup.
const char *name;
// This is a raw function pointer to the function's entry point as
// emitted by the compiler.
ABIFunc *functionPtr;
// The number of inputs to the function.
std::int32_t numInputs;
// The number of outputs of the function.
std::int32_t numOutputs;
// TODO: Add arg/result descriptors and other metadata.
// With those descriptors we can do type and shape checking for each
// argument.
};
// The top-level entry point of the module metadata emitted by the
// compiler. Unlike all the other descriptors here, external code does handle
// this type (albeit through an opaque pointer).
struct ModuleDescriptor {
std::int32_t numFuncDescriptors;
FuncDescriptor *functionDescriptors;
};
// Static data representing a global variable (together with its shape).
struct GlobalDescriptor {
std::int32_t numExtents;
std::int32_t *extents;
void *data;
};
} // namespace npcomprt
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H

View File

@ -15,6 +15,7 @@
#include <cstdlib> #include <cstdlib>
#include <iostream> #include <iostream>
#include "CompilerDataStructures.h"
#include "npcomp/runtime/UserAPI.h" #include "npcomp/runtime/UserAPI.h"
using namespace npcomprt; using namespace npcomprt;
@ -118,4 +119,12 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank,
extents32Buf[i] = extents64[i]; extents32Buf[i] = extents64[i];
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank), return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
elementType, data); elementType, data);
}
extern "C" UnrankedMemref
__npcomp_compiler_rt_get_global(GlobalDescriptor *global) {
auto *descriptor = MemrefDescriptor::create(
ArrayRef<std::int32_t>(global->extents, global->numExtents),
global->data);
return UnrankedMemref{global->numExtents, descriptor};
} }

View File

@ -13,6 +13,8 @@
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include "CompilerDataStructures.h"
using namespace npcomprt; using namespace npcomprt;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -63,39 +65,7 @@ std::int32_t Tensor::getDataByteSize() const {
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Module metadata descriptors. // Module metadata descriptors.
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// These descriptors are never created by runtime code. They are always
// embedded by the compiler as static data inside the module.
//
// Their definitions need to be kept in sync with the compiler code in
// LowerToLLVM.cpp
// All arguments are packed into this type-erased form for being invoked. See
// LowerToLLVM.cpp for more details.
typedef void ABIFunc(void **, void **);
namespace {
struct FuncDescriptor {
// The length of the function name.
std::int32_t nameLen;
// The name of the function, to allow lookup.
const char *name;
// This is a raw function pointer to the function's entry point as
// emitted by the compiler.
ABIFunc *functionPtr;
std::int32_t numInputs;
std::int32_t numOutputs;
// TODO: Add arg/result descriptors and other metadata.
// With those descriptors.
};
} // namespace
// The top-level entry point of the module metadata emitted by the
// compiler.
struct npcomprt::ModuleDescriptor {
std::int32_t numFuncDescriptors;
// TODO: Update compiler code to emit this as a separate global.
FuncDescriptor *functionDescriptors;
};
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Module operations. // Module operations.

View File

@ -21,3 +21,23 @@ npcomprt.module_metadata {
npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32} npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
} }
func @f() { return } func @f() { return }
// -----
npcomprt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid npcomprt.global}}
npcomprt.get_global @nonexistent_symbol : memref<*xf32>
return
}
// -----
npcomprt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
npcomprt.get_global @g : memref<*xi8>
return
}

View File

@ -12,3 +12,9 @@ func @f(%arg0: !npcomprt.tensor) {
return return
} }
// CHECK-LABEL: npcomprt.global @g dense<0.0{{.*}}> : tensor<10xf32>
npcomprt.global @g dense<0.0> : tensor<10xf32>
func @uses_global() {
npcomprt.get_global @g : memref<*xf32>
return
}

View File

@ -0,0 +1,31 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s
// -----
tcp.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid symbol}}
tcp.get_global_memref @nonexistent_symbol : memref<3xf32>
return
}
// -----
tcp.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with shape of global}}
tcp.get_global_memref @g : memref<3xf32>
return
}
// -----
tcp.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
tcp.get_global_memref @g : memref<2xi8>
return
}

View File

@ -1,7 +1,11 @@
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail // RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
tcp.global @foo dense<0.0> : tensor<10xf32>
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) { func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
// CHECK: tcp.add // CHECK: tcp.add
%0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> %0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%1 = tcp.get_global_memref @foo : memref<10xf32>
return return
} }

View File

@ -0,0 +1,61 @@
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memrefs <%s | FileCheck %s
// CHECK-LABEL: module {
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
// CHECK: tcp.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
// CHECK: func @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
%0 = constant dense<7.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0 : tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// Only one global is created.
// CHECK: tcp.global
// CHECK-NOT: tcp.global
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<7.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// Two globals are created.
// CHECK: tcp.global
// CHECK: tcp.global
// CHECK-NOT: tcp.global
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<8.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// We don't convert non-tensor globals.
// CHECK-NOT: tcp.global
func @non_tensor() {
%0 = constant 7 : i32
return
}
// CHECK: }

View File

@ -0,0 +1,32 @@
// RUN: npcomp-opt -e2e-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm<"[3 x float]">
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm<"[1 x i32]">
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm<"{ i32, i32*, i8* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, i32*, i8* }">
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, i32*, i8* }">
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomprt_global_extents_g : !llvm<"[1 x i32]*">
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x i32]*"> to !llvm<"i32*">
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, i32*, i8* }">
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__npcomprt_global_data_buffer_g : !llvm<"[3 x float]*">
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm<"[3 x float]*"> to !llvm<"i8*">
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm<"{ i32, i32*, i8* }">
// CHECK: llvm.return %[[VAL_8]] : !llvm<"{ i32, i32*, i8* }">
// CHECK: }
// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm<"{ i64, i8* }"> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm<"{ i32, i32*, i8* }*">
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm<"{ i32, i32*, i8* }*">) -> !llvm<"{ i64, i8* }">
npcomprt.global @g dense<7.000000e+00> : tensor<3xf32>
func @calls_get_global() -> memref<*xf32> {
%0 = npcomprt.get_global @g : memref<*xf32>
return %0 : memref<*xf32>
}
// -----
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm<"[1 x float]">
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm<"[1 x i32]">
npcomprt.global @g dense<7.0> : tensor<f32>

View File

@ -37,6 +37,23 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
// ----- // -----
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
tcp.global @g dense<7.0> : tensor<10xf32>
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
func @gets_global() -> tensor<10xf32> {
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
// CHECK: %[[RETMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[RETMEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : !npcomprt.tensor
%0 = tcp.get_global_memref @g : memref<10xf32>
%1 = tensor_load %0 : memref<10xf32>
return %1 : tensor<10xf32>
}
// -----
// expected-error @+1 {{func not expressible with npcomprt ABI}} // expected-error @+1 {{func not expressible with npcomprt ABI}}
func @unhandled_abi_type_on_public_func(%arg0: i32) { func @unhandled_abi_type_on_public_func(%arg0: i32) {
return return

View File

@ -0,0 +1,12 @@
// RUN: npcomp-run-mlir -input %s \
// RUN: -invoke constant_add_scalar \
// RUN: -arg-value="dense<3.0> : tensor<f32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<4.000000e+00> : tensor<f32>
func @constant_add_scalar(%arg0: tensor<f32>) -> tensor<f32> {
%0 = constant dense<1.0> : tensor<f32>
%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %1 : tensor<f32>
}

View File

@ -0,0 +1,12 @@
// RUN: npcomp-run-mlir -input %s \
// RUN: -invoke constant_add \
// RUN: -arg-value="dense<[3.0, 5.0]> : tensor<2xf32>" \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<[4.000000e+00, 7.000000e+00]> : tensor<2xf32>
func @constant_add(%arg0: tensor<2xf32>) -> tensor<2xf32> {
%0 = constant dense<[1.0, 2.0]> : tensor<2xf32>
%1 = "tcf.add"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
return %1 : tensor<2xf32>
}

View File

@ -0,0 +1,10 @@
// RUN: npcomp-run-mlir -input %s \
// RUN: -invoke constant \
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
// RUN: | FileCheck %s
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
func @constant() -> tensor<f32> {
%0 = constant dense<1.0> : tensor<f32>
return %0 : tensor<f32>
}

View File

@ -47,3 +47,5 @@ npctall() {
# https://superuser.com/q/253068 # https://superuser.com/q/253068
export FIGNORE=$FIGNORE:nstall-mlir export FIGNORE=$FIGNORE:nstall-mlir
export PYTHONPATH="$(realpath ${build_dir}/python):$(realpath ${build_dir}/python_native):$(realpath ${build_dir}/iree/bindings/python)"