mirror of https://github.com/llvm/torch-mlir
npcomprt: add support for constants
- create tcp.global + tcp.get_global_memref - create npcomprt.global + npcomprt.get_global - LLVM lowering for new npcomprt ops - Runtime: - GlobalDescriptor struct emitted by LLVM lowering - implement __npcomp_compiler_rt_get_global Also, - cleanly isolate all runtime data structure definitions shared by the compiler and runtime into lib/runtime/CompilerDataStructures.hpull/1/head
parent
2e40ce05ad
commit
e228aa4b11
|
@ -27,7 +27,7 @@ def Npcomprt_Tensor
|
|||
: DialectType<
|
||||
Npcomprt_Dialect,
|
||||
CPred<"$_self.isa<::mlir::NPCOMP::npcomprt::TensorType>()">,
|
||||
"buffer view">,
|
||||
"npcomprt.tensor">,
|
||||
BuildableType<
|
||||
"$_builder.getType<::mlir::NPCOMP::npcomprt::TensorType>()"> {
|
||||
let typeDescription = [{The runtime type that represents a buffer.}];
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace NPCOMP {
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
#define NPCOMPRT_OPS
|
||||
|
||||
include "npcomp/Dialect/Npcomprt/IR/NpcomprtBase.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
class Npcomprt_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<Npcomprt_Dialect, mnemonic, traits> {
|
||||
|
@ -57,6 +58,44 @@ def Npcomprt_AbortIfOp : Npcomprt_Op<"abort_if"> {
|
|||
let assemblyFormat = "$pred attr-dict";
|
||||
}
|
||||
|
||||
def Npcomprt_GlobalOp : Npcomprt_Op<"global", [Symbol]> {
|
||||
let summary = "Represents a global variable";
|
||||
let description = [{
|
||||
Represents a global variable.
|
||||
|
||||
Currently, only constant tensors are supported, and they are not
|
||||
considered to be exported.
|
||||
}];
|
||||
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||
let results = (outs);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def Npcomprt_GetGlobalOp : Npcomprt_Op<"get_global"> {
|
||||
let summary = "Obtain a rank-erased memref pointing at the given global";
|
||||
let description = [{
|
||||
Obtain a rank-erased memref pointing at the given global.
|
||||
|
||||
TODO: As we define the runtime layer better, we should have fewer
|
||||
entry points that return memrefs, or at least have a clearer separation
|
||||
between the "memref world" and the "npcomprt world".
|
||||
Something like forming IREE dispatch regions seems to be the missing thing:
|
||||
- Everything inside the dispatch regions gets things marshaled from the
|
||||
runtime (flow/hal/npcomprt) layer to/from memrefs in a clear way.
|
||||
- Everything outside the dispatch regions purely uses the runtime
|
||||
(flow/hal/npcomprt) data structures.
|
||||
Globals should be one of the things that are purely runtime data structures,
|
||||
rather than using memrefs. For now, using memrefs is simpler though.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||
let results = (outs AnyUnrankedMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
// TODO: verify exists and shape is compatible
|
||||
}
|
||||
|
||||
def Npcomprt_ModuleMetadataOp : Npcomprt_Op<"module_metadata", [
|
||||
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
|
||||
]> {
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -13,6 +13,7 @@ include "npcomp/Dialect/TCP/IR/TCPBase.td"
|
|||
include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/IR/SymbolInterfaces.td"
|
||||
|
||||
class TCP_Op<string mnemonic, list<OpTrait> traits = []>
|
||||
: Op<TCP_Dialect, mnemonic, traits> {
|
||||
|
@ -58,6 +59,32 @@ Allocates a memref of the given shape.
|
|||
let assemblyFormat = "$shape attr-dict `:` type($memref)";
|
||||
}
|
||||
|
||||
def TCP_GlobalOp : TCP_Op<"global", [Symbol]> {
|
||||
let summary = "Represents a global variable";
|
||||
let description = [{
|
||||
Represents a global variable.
|
||||
|
||||
Currently, only constant tensors are supported, and they are not
|
||||
considered to be exported.
|
||||
}];
|
||||
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
|
||||
let results = (outs);
|
||||
|
||||
let printer = [{ return ::print$cppClass(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def TCP_GetGlobalMemrefOp : TCP_Op<"get_global_memref"> {
|
||||
let summary = "Obtain a memref pointing at the given global";
|
||||
let description = [{
|
||||
Obtain a memref pointing at the given global.
|
||||
}];
|
||||
let arguments = (ins FlatSymbolRefAttr:$global);
|
||||
let results = (outs AnyMemRef:$memref);
|
||||
let assemblyFormat = "$global attr-dict `:` type($memref)";
|
||||
let verifier = "return ::verify$cppClass(*this);";
|
||||
}
|
||||
|
||||
// TODO: Change to a more principled error handling mechanism.
|
||||
// This op probably doesn't need to exist eventually.
|
||||
// This op is also not correctly modeled right now, since it itself doesn't
|
||||
|
|
|
@ -25,6 +25,9 @@ std::unique_ptr<OperationPass<FuncOp>> createLowerBroadcastToToLoopsPass();
|
|||
std::unique_ptr<OperationPass<FuncOp>>
|
||||
createLowerLinalgOnTensorToLinalgOnMemrefPass();
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
createLowerConstantTensorsToMemrefsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createResolveShapeOfOpsPass();
|
||||
|
||||
std::unique_ptr<OperationPass<FuncOp>> createResolveTensorLoadStoreOpsPass();
|
||||
|
|
|
@ -23,6 +23,15 @@ def LowerBroadcastToToLoops :
|
|||
let constructor = "mlir::NPCOMP::createLowerBroadcastToToLoopsPass()";
|
||||
}
|
||||
|
||||
def LowerConstantTensorsToMemrefs :
|
||||
Pass<"lower-constant-tensors-to-memrefs", "ModuleOp"> {
|
||||
let summary = "Lower std.constant of tensor type to hybrid tensor/memref.";
|
||||
let description = [{
|
||||
This has to be a module pass since it involves creating tcp.global ops.
|
||||
}];
|
||||
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass()";
|
||||
}
|
||||
|
||||
def ResolveShapeOfOps : Pass<"resolve-shape-of-ops", "FuncOp"> {
|
||||
let summary = "Resolve shape.shape_of ops to other shapes.";
|
||||
let constructor = "mlir::NPCOMP::createResolveShapeOfOpsPass()";
|
||||
|
|
|
@ -10,11 +10,53 @@
|
|||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/SymbolTable.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "npcomp/Dialect/Npcomprt/IR/NpcomprtDialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP::npcomprt;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||
p << "npcomprt.global ";
|
||||
p.printSymbolName(op.sym_name());
|
||||
p << ' ';
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
/*elidedAttrs=*/{"sym_name", "value"});
|
||||
p.printAttribute(op.valueAttr());
|
||||
}
|
||||
|
||||
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
Attribute valueAttr;
|
||||
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
|
||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||
if (!global)
|
||||
return op.emitError() << "must reference a valid npcomprt.global";
|
||||
auto globalType = global.value().getType();
|
||||
auto resultType = op.getType().cast<ShapedType>();
|
||||
if (globalType.getElementType() != resultType.getElementType())
|
||||
return op.emitError() << "inconsistent with element type of global";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleMetadataOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -8,11 +8,56 @@
|
|||
|
||||
#include "npcomp/Dialect/TCP/IR/TCPOps.h"
|
||||
#include "mlir/Dialect/Shape/IR/Shape.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::NPCOMP;
|
||||
using namespace mlir::NPCOMP::tcp;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
|
||||
p << "tcp.global ";
|
||||
p.printSymbolName(op.sym_name());
|
||||
p << ' ';
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
|
||||
/*elidedAttrs=*/{"sym_name", "value"});
|
||||
p.printAttribute(op.valueAttr());
|
||||
}
|
||||
|
||||
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
|
||||
StringAttr nameAttr;
|
||||
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
|
||||
result.attributes))
|
||||
return failure();
|
||||
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
|
||||
return failure();
|
||||
Attribute valueAttr;
|
||||
if (parser.parseAttribute(valueAttr, "value", result.attributes))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GetGlobalMemrefOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
|
||||
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
|
||||
if (!global)
|
||||
return op.emitError() << "must reference a valid symbol";
|
||||
auto globalType = global.value().getType();
|
||||
auto resultType = op.getType().cast<ShapedType>();
|
||||
if (failed(
|
||||
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
|
||||
return op.emitError() << "inconsistent with shape of global";
|
||||
if (globalType.getElementType() != resultType.getElementType())
|
||||
return op.emitError() << "inconsistent with element type of global";
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeObserveErrorOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -359,7 +359,8 @@ void mlir::NPCOMP::createE2ELoweringPipeline(
|
|||
pm.addPass(createResolveTensorLoadStoreOpsPass());
|
||||
|
||||
// At this point, the IR is in a form where there are no tensor ops
|
||||
// (except tensor_store's of arguments and tensor_load's of returns).
|
||||
// (except tensor_store's of arguments, tensor_load's of returns, and
|
||||
// constants).
|
||||
//
|
||||
// This is a reasonable representation for doing buffer assignment.
|
||||
// TODO: Do buffer assignment here.
|
||||
|
|
|
@ -296,6 +296,91 @@ mlir::NPCOMP::createLowerLinalgOnTensorToLinalgOnMemrefPass() {
|
|||
return std::make_unique<LowerLinalgOnTensorToLinalgOnMemref>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LowerConstantTensorsToMemrefs
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
// This class creates global ops for all tensor-valued constants in the program.
|
||||
// It creates them with pretty names and makes sure that duplicate globals
|
||||
// aren't created.
|
||||
class GlobalCreator {
|
||||
public:
|
||||
explicit GlobalCreator(ModuleOp module);
|
||||
tcp::GlobalOp getGlobalFor(Attribute attr) {
|
||||
assert(globals.find(attr) != globals.end() && "unknown constant attr");
|
||||
return globals[attr];
|
||||
}
|
||||
|
||||
private:
|
||||
DenseMap<Attribute, tcp::GlobalOp> globals;
|
||||
};
|
||||
|
||||
GlobalCreator::GlobalCreator(ModuleOp module) {
|
||||
// Create a builder without an insertion point. We will insert using the
|
||||
// symbol table to guarantee unique names.
|
||||
OpBuilder globalBuilder(module.getContext());
|
||||
SymbolTable symbolTable(module);
|
||||
module.walk([&](ConstantOp op) {
|
||||
// We only want tensor constants for now.
|
||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return;
|
||||
// If we already have a global for this constant value, no need to do
|
||||
// anything else.
|
||||
auto it = globals.find(op.getValue());
|
||||
if (it != globals.end())
|
||||
return;
|
||||
|
||||
// Create a pretty name.
|
||||
SmallString<64> buf;
|
||||
llvm::raw_svector_ostream os(buf);
|
||||
interleave(type.getShape(), os, "x");
|
||||
os << "x" << type.getElementType();
|
||||
|
||||
auto global = globalBuilder.create<tcp::GlobalOp>(
|
||||
op.getLoc(), (Twine("__constant_") + os.str()).str(),
|
||||
op.getValue().cast<ElementsAttr>());
|
||||
symbolTable.insert(global);
|
||||
// The symbol table inserts at the end of the module, but globals are a bit
|
||||
// nicer if they are at the beginning.
|
||||
global.getOperation()->moveBefore(&module.front());
|
||||
globals[op.getValue()] = global;
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerConstantTensorsToMemrefs
|
||||
: public LowerConstantTensorsToMemrefsBase<LowerConstantTensorsToMemrefs> {
|
||||
void runOnOperation() {
|
||||
auto module = getOperation();
|
||||
GlobalCreator globals(module);
|
||||
|
||||
// With the global traversal factored into GlobalCreator, this could in
|
||||
// principle be done with a pattern.
|
||||
module.walk([&](ConstantOp op) {
|
||||
auto type = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!type)
|
||||
return;
|
||||
auto global = globals.getGlobalFor(op.getValue());
|
||||
OpBuilder builder(op);
|
||||
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
|
||||
auto memref = builder.create<tcp::GetGlobalMemrefOp>(
|
||||
op.getLoc(), memrefType, global.getName());
|
||||
Value tensor = builder.create<TensorLoadOp>(op.getLoc(), type, memref);
|
||||
op.replaceAllUsesWith(tensor);
|
||||
op.erase();
|
||||
});
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>>
|
||||
mlir::NPCOMP::createLowerConstantTensorsToMemrefsPass() {
|
||||
return std::make_unique<LowerConstantTensorsToMemrefs>();
|
||||
}
|
||||
|
||||
void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
|
||||
// Lower to hybrid tensor/memref.
|
||||
// The invariant of "hybrid tensor/memref" is that the core computation
|
||||
|
@ -305,6 +390,7 @@ void mlir::NPCOMP::createLowerToHybridTensorMemRefPipeline(OpPassManager &pm) {
|
|||
// allocated with alloc_shape ops.
|
||||
// Thus, shape.shape_of ops on the original tensors in the program can be
|
||||
// resolved to the shapes in the alloc_memref calls.
|
||||
pm.addPass(createLowerConstantTensorsToMemrefsPass());
|
||||
pm.addPass(createLowerLinalgOnTensorToLinalgOnMemrefPass());
|
||||
pm.addPass(createLowerBroadcastToToLoopsPass());
|
||||
}
|
||||
|
|
|
@ -22,6 +22,28 @@ using namespace mlir::NPCOMP;
|
|||
using mlir::LLVM::LLVMFuncOp;
|
||||
using mlir::LLVM::LLVMType;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utilities.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO: Move other descriptor types to here.
|
||||
|
||||
// Get the LLVMType for npcomprt::GlobalDescriptor.
|
||||
static LLVMType getGlobalDescriptorTy(LLVM::LLVMDialect *llvmDialect) {
|
||||
return LLVMType::getStructTy(
|
||||
// std::int32_t numExtents;
|
||||
LLVMType::getIntNTy(llvmDialect, 32),
|
||||
// std::int32_t *extents;
|
||||
LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(),
|
||||
// It is important that this struct member is a type-erased pointer
|
||||
// so that this type is "context-free" and can be created in conversion
|
||||
// patterns independently of the actual type of the data stored in the
|
||||
// buffer.
|
||||
//
|
||||
// void *data;
|
||||
LLVMType::getInt8PtrTy(llvmDialect));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Compiler runtime functions.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -72,6 +94,36 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class GetGlobalOpCompilerRuntimeLowering
|
||||
: public OpConversionPattern<npcomprt::GetGlobalOp> {
|
||||
public:
|
||||
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
|
||||
: OpConversionPattern<npcomprt::GetGlobalOp>(backingFunc.getContext()),
|
||||
backingFunc(backingFunc) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(npcomprt::GetGlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto llvmDialect =
|
||||
rewriter.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
|
||||
// It would be nice if we could use the constructor here that takes just the
|
||||
// global, but keeping track of the converted llvm.mlir.global op that gets
|
||||
// created from the npcomprt.global while conversion is going on is a
|
||||
// headache.
|
||||
//
|
||||
// Instead, we rely on the symbol name being the same and the result type
|
||||
// always being the same.
|
||||
auto globalAddr = rewriter.create<LLVM::AddressOfOp>(
|
||||
op.getLoc(), getGlobalDescriptorTy(llvmDialect).getPointerTo(),
|
||||
op.globalAttr());
|
||||
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc,
|
||||
ValueRange({globalAddr}));
|
||||
return success();
|
||||
}
|
||||
LLVM::LLVMFuncOp backingFunc;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// Create the LLVM runtime function backing the npcomprt op with name `name`
|
||||
// and requiring `type`.
|
||||
static LLVMFuncOp createCompilerRuntimeFuncDecl(StringRef name, LLVMType type,
|
||||
|
@ -140,8 +192,164 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
|
|||
"from_memref", funcTy, builder, module.getLoc());
|
||||
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
|
||||
}
|
||||
|
||||
{
|
||||
// Hardcoding f32 is fine here, since unranked memref descriptors have
|
||||
// identical struct layout / ABI / contents regardless of the element type.
|
||||
auto mlirFunctionType = builder.getFunctionType(
|
||||
{getGlobalDescriptorTy(llvmDialect).getPointerTo()},
|
||||
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
|
||||
LLVMType funcTy = convertFunctionType(mlirFunctionType);
|
||||
LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl(
|
||||
"get_global", funcTy, builder, module.getLoc());
|
||||
patterns.insert<GetGlobalOpCompilerRuntimeLowering>(backingFunc);
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lowering for npcomprt.global
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
class LowerNpcomprtGlobalOp : public OpConversionPattern<npcomprt::GlobalOp> {
|
||||
public:
|
||||
explicit LowerNpcomprtGlobalOp(LLVMTypeConverter &typeConverter)
|
||||
: OpConversionPattern<npcomprt::GlobalOp>(&typeConverter.getContext()),
|
||||
typeConverter(typeConverter) {}
|
||||
LogicalResult
|
||||
matchAndRewrite(npcomprt::GlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto llvmDialect = typeConverter.getDialect();
|
||||
auto globalDescriptorTy = getGlobalDescriptorTy(llvmDialect);
|
||||
|
||||
// Create the data buffer.
|
||||
auto dataBuffer = createGlobalForDenseElementsAttr(
|
||||
(Twine("__npcomprt_global_data_buffer_") + op.sym_name()).str(),
|
||||
op.value().cast<DenseElementsAttr>(), op, rewriter);
|
||||
|
||||
// Create the extents buffer.
|
||||
auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>(
|
||||
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
|
||||
[](int64_t i) -> int32_t { return i; })));
|
||||
auto extentsBuffer = createGlobalForDenseElementsAttr(
|
||||
(Twine("__npcomprt_global_extents_") + op.sym_name()).str(), extentsI32,
|
||||
op, rewriter);
|
||||
|
||||
// Create the GlobalDescriptor.
|
||||
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
|
||||
op.getLoc(), globalDescriptorTy, /*isConstant=*/true,
|
||||
LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute());
|
||||
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.createBlock(&globalDescriptorGlobal.initializer());
|
||||
|
||||
// Create the body of the initializer.
|
||||
Value globalDescriptor =
|
||||
rewriter.create<LLVM::UndefOp>(op.getLoc(), globalDescriptorTy);
|
||||
auto updateDescriptor = [&](Value value,
|
||||
std::initializer_list<int32_t> position) {
|
||||
globalDescriptor = rewriter.create<LLVM::InsertValueOp>(
|
||||
op.getLoc(), globalDescriptor, value,
|
||||
/*position=*/rewriter.getI32ArrayAttr(position));
|
||||
};
|
||||
updateDescriptor(
|
||||
rewriter.create<LLVM::ConstantOp>(
|
||||
op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32),
|
||||
rewriter.getI32IntegerAttr(
|
||||
op.value().getType().cast<ShapedType>().getRank())),
|
||||
{0});
|
||||
|
||||
// The global is actually an array, so we need to get a bare i32* pointer
|
||||
// type. We could do this with GEP but it would be more verbose.
|
||||
auto extentsBufferArrayAddress =
|
||||
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), extentsBuffer);
|
||||
auto extentsBufferAddress = rewriter.create<LLVM::BitcastOp>(
|
||||
op.getLoc(), LLVMType::getIntNTy(llvmDialect, 32).getPointerTo(),
|
||||
extentsBufferArrayAddress);
|
||||
updateDescriptor(extentsBufferAddress, {1});
|
||||
|
||||
auto dataBufferAddress =
|
||||
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), dataBuffer);
|
||||
auto typeErasedDataBufferAddress = rewriter.create<LLVM::BitcastOp>(
|
||||
op.getLoc(), LLVMType::getInt8PtrTy(llvmDialect), dataBufferAddress);
|
||||
updateDescriptor(typeErasedDataBufferAddress, {2});
|
||||
rewriter.create<LLVM::ReturnOp>(op.getLoc(), globalDescriptor);
|
||||
|
||||
rewriter.eraseOp(op);
|
||||
return success();
|
||||
}
|
||||
|
||||
private:
|
||||
// TODO: It feels like MLIR core should have better utilities for this.
|
||||
LLVM::GlobalOp createGlobalForDenseElementsAttr(
|
||||
StringRef symbolName, DenseElementsAttr elements, npcomprt::GlobalOp op,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto type = elements.getType().cast<ShapedType>();
|
||||
|
||||
// LLVM translation doesn't handle the case of zero-sized tensors, which can
|
||||
// happen e.g. for the number of extents of a rank-0 (i.e. scalar).
|
||||
//
|
||||
// We fake-up a size-1 DenseElementsAttr to use for creating the global.
|
||||
// That takes up binary space (one element instead of zero), but that seems
|
||||
// fine.
|
||||
//
|
||||
// TODO: LLVM translation in MLIR core should handle this case better.
|
||||
if (type.getNumElements() == 0) {
|
||||
auto elementType = type.getElementType();
|
||||
Attribute singleElement;
|
||||
if (elementType.isIntOrIndex())
|
||||
singleElement = rewriter.getIntegerAttr(elementType, 0);
|
||||
else if (elementType.isa<FloatType>())
|
||||
singleElement = rewriter.getFloatAttr(elementType, 0);
|
||||
assert(singleElement &&
|
||||
"could not fake up an element for a zero element tensor");
|
||||
type = RankedTensorType::get({1}, elementType);
|
||||
elements =
|
||||
DenseElementsAttr::get(type, ArrayRef<Attribute>(singleElement));
|
||||
}
|
||||
|
||||
auto llvmType = getLLVMTypeForShapedType(type, op, rewriter);
|
||||
return rewriter.create<LLVM::GlobalOp>(
|
||||
op.getLoc(), llvmType,
|
||||
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
|
||||
}
|
||||
|
||||
LLVMType getLLVMTypeForShapedType(ShapedType type, npcomprt::GlobalOp op,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
auto llvmType =
|
||||
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
|
||||
|
||||
// MLIR->LLVM lowering for globals requires non-scalar data types. So use a
|
||||
// dummy size-1 array for the scalar case.
|
||||
//
|
||||
// TODO: LLVM translation in MLIR core should handle this case better.
|
||||
if (type.getRank() == 0)
|
||||
return LLVMType::getArrayTy(llvmType, 1);
|
||||
|
||||
if (!llvmType) {
|
||||
rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
|
||||
diag << "cannot convert element type " << type.getElementType()
|
||||
<< " to an LLVM type";
|
||||
});
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Construct an LLVM nested array type for the tensor initializer.
|
||||
// tensor<f32> -> float
|
||||
// tensor<10xf32> -> [10 x float]
|
||||
// tensor<2x3xf32> -> [2 x [3 x float]]
|
||||
assert(type.hasStaticShape());
|
||||
auto shape = type.getShape();
|
||||
while (!shape.empty()) {
|
||||
llvmType = LLVMType::getArrayTy(llvmType, shape.back());
|
||||
shape = shape.drop_back();
|
||||
}
|
||||
return llvmType;
|
||||
}
|
||||
LLVMTypeConverter &typeConverter;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Lowering for module metadata
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -443,6 +651,7 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
|
|||
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
patterns.insert<LowerModuleMetadata>(context);
|
||||
patterns.insert<LowerNpcomprtGlobalOp>(converter);
|
||||
|
||||
if (failed(applyFullConversion(module, target, patterns))) {
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -139,6 +139,39 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerGlobalOp : public OpConversionPattern<tcp::GlobalOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::GlobalOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<npcomprt::GlobalOp>(op, op.sym_name(),
|
||||
op.value());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
namespace {
|
||||
class LowerGetGlobalMemrefOp
|
||||
: public OpConversionPattern<tcp::GetGlobalMemrefOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(tcp::GetGlobalMemrefOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto abiMemrefType = UnrankedMemRefType::get(
|
||||
op.getType().cast<ShapedType>().getElementType(), /*memorySpace=*/0);
|
||||
auto abiMemref = rewriter.create<npcomprt::GetGlobalOp>(
|
||||
op.getLoc(), abiMemrefType, op.global());
|
||||
// Cast back to the original type.
|
||||
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult doDialectConversion(ModuleOp module) {
|
||||
auto *context = module.getContext();
|
||||
|
||||
|
@ -172,6 +205,14 @@ static LogicalResult doDialectConversion(ModuleOp module) {
|
|||
target.addLegalOp<shape::FromExtentsOp>();
|
||||
target.addLegalOp<npcomprt::GetExtentOp>();
|
||||
|
||||
patterns.insert<LowerGlobalOp>(context);
|
||||
target.addIllegalOp<tcp::GlobalOp>();
|
||||
target.addLegalOp<npcomprt::GlobalOp>();
|
||||
|
||||
patterns.insert<LowerGetGlobalMemrefOp>(context);
|
||||
target.addIllegalOp<tcp::GetGlobalMemrefOp>();
|
||||
target.addLegalOp<npcomprt::GetGlobalOp>();
|
||||
|
||||
return applyPartialConversion(module, target, patterns);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains data structures (which we typically call "descriptors")
|
||||
// that are emitted by the compiler and must be kept in sync with the compiler
|
||||
// code that creates them in LowerToLLVM.cpp.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||
#define NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
namespace npcomprt {
|
||||
|
||||
// All arguments are packed into this type-erased form for being invoked. See
|
||||
// LowerToLLVM.cpp for more details.
|
||||
typedef void ABIFunc(void **, void **);
|
||||
|
||||
struct FuncDescriptor {
|
||||
// The length of the function name.
|
||||
std::int32_t nameLen;
|
||||
// The name of the function, to allow lookup.
|
||||
const char *name;
|
||||
// This is a raw function pointer to the function's entry point as
|
||||
// emitted by the compiler.
|
||||
ABIFunc *functionPtr;
|
||||
// The number of inputs to the function.
|
||||
std::int32_t numInputs;
|
||||
// The number of outputs of the function.
|
||||
std::int32_t numOutputs;
|
||||
// TODO: Add arg/result descriptors and other metadata.
|
||||
// With those descriptors we can do type and shape checking for each
|
||||
// argument.
|
||||
};
|
||||
|
||||
// The top-level entry point of the module metadata emitted by the
|
||||
// compiler. Unlike all the other descriptors here, external code does handle
|
||||
// this type (albeit through an opaque pointer).
|
||||
struct ModuleDescriptor {
|
||||
std::int32_t numFuncDescriptors;
|
||||
FuncDescriptor *functionDescriptors;
|
||||
};
|
||||
|
||||
// Static data representing a global variable (together with its shape).
|
||||
struct GlobalDescriptor {
|
||||
std::int32_t numExtents;
|
||||
std::int32_t *extents;
|
||||
void *data;
|
||||
};
|
||||
|
||||
} // namespace npcomprt
|
||||
|
||||
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H
|
|
@ -15,6 +15,7 @@
|
|||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
|
||||
#include "CompilerDataStructures.h"
|
||||
#include "npcomp/runtime/UserAPI.h"
|
||||
|
||||
using namespace npcomprt;
|
||||
|
@ -118,4 +119,12 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank,
|
|||
extents32Buf[i] = extents64[i];
|
||||
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
|
||||
elementType, data);
|
||||
}
|
||||
|
||||
extern "C" UnrankedMemref
|
||||
__npcomp_compiler_rt_get_global(GlobalDescriptor *global) {
|
||||
auto *descriptor = MemrefDescriptor::create(
|
||||
ArrayRef<std::int32_t>(global->extents, global->numExtents),
|
||||
global->data);
|
||||
return UnrankedMemref{global->numExtents, descriptor};
|
||||
}
|
|
@ -13,6 +13,8 @@
|
|||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
#include "CompilerDataStructures.h"
|
||||
|
||||
using namespace npcomprt;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -63,39 +65,7 @@ std::int32_t Tensor::getDataByteSize() const {
|
|||
//===----------------------------------------------------------------------===//
|
||||
// Module metadata descriptors.
|
||||
//===----------------------------------------------------------------------===//
|
||||
// These descriptors are never created by runtime code. They are always
|
||||
// embedded by the compiler as static data inside the module.
|
||||
//
|
||||
// Their definitions need to be kept in sync with the compiler code in
|
||||
// LowerToLLVM.cpp
|
||||
|
||||
// All arguments are packed into this type-erased form for being invoked. See
|
||||
// LowerToLLVM.cpp for more details.
|
||||
typedef void ABIFunc(void **, void **);
|
||||
|
||||
namespace {
|
||||
struct FuncDescriptor {
|
||||
// The length of the function name.
|
||||
std::int32_t nameLen;
|
||||
// The name of the function, to allow lookup.
|
||||
const char *name;
|
||||
// This is a raw function pointer to the function's entry point as
|
||||
// emitted by the compiler.
|
||||
ABIFunc *functionPtr;
|
||||
std::int32_t numInputs;
|
||||
std::int32_t numOutputs;
|
||||
// TODO: Add arg/result descriptors and other metadata.
|
||||
// With those descriptors.
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// The top-level entry point of the module metadata emitted by the
|
||||
// compiler.
|
||||
struct npcomprt::ModuleDescriptor {
|
||||
std::int32_t numFuncDescriptors;
|
||||
// TODO: Update compiler code to emit this as a separate global.
|
||||
FuncDescriptor *functionDescriptors;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Module operations.
|
||||
|
|
|
@ -21,3 +21,23 @@ npcomprt.module_metadata {
|
|||
npcomprt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
|
||||
}
|
||||
func @f() { return }
|
||||
|
||||
// -----
|
||||
|
||||
npcomprt.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{must reference a valid npcomprt.global}}
|
||||
npcomprt.get_global @nonexistent_symbol : memref<*xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
npcomprt.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with element type of global}}
|
||||
npcomprt.get_global @g : memref<*xi8>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -12,3 +12,9 @@ func @f(%arg0: !npcomprt.tensor) {
|
|||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: npcomprt.global @g dense<0.0{{.*}}> : tensor<10xf32>
|
||||
npcomprt.global @g dense<0.0> : tensor<10xf32>
|
||||
func @uses_global() {
|
||||
npcomprt.get_global @g : memref<*xf32>
|
||||
return
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s
|
||||
|
||||
// -----
|
||||
|
||||
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{must reference a valid symbol}}
|
||||
tcp.get_global_memref @nonexistent_symbol : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with shape of global}}
|
||||
tcp.get_global_memref @g : memref<3xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
tcp.global @g dense<0.0> : tensor<2xf32>
|
||||
|
||||
func @f() {
|
||||
// expected-error @+1 {{inconsistent with element type of global}}
|
||||
tcp.get_global_memref @g : memref<2xi8>
|
||||
return
|
||||
}
|
|
@ -1,7 +1,11 @@
|
|||
// RUN: npcomp-opt <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK-LABEL: tcp.global @foo dense<0.0{{.*}}> : tensor<10xf32>
|
||||
tcp.global @foo dense<0.0> : tensor<10xf32>
|
||||
|
||||
func @f(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %arg2: i32) {
|
||||
// CHECK: tcp.add
|
||||
%0 = "tcp.add"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
|
||||
%1 = tcp.get_global_memref @foo : memref<10xf32>
|
||||
return
|
||||
}
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memrefs <%s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
// We check the debug name too since we put some effort into making that readable.
|
||||
// The name isn't load-bearing though.
|
||||
// CHECK: tcp.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
|
||||
// CHECK: func @basic
|
||||
func @basic() -> tensor<3x4xf32> {
|
||||
// CHECK: %[[MEMREF:.*]] = tcp.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
|
||||
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0 : tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
|
||||
// Only one global is created.
|
||||
// CHECK: tcp.global
|
||||
// CHECK-NOT: tcp.global
|
||||
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
%1 = constant dense<7.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
|
||||
// Two globals are created.
|
||||
// CHECK: tcp.global
|
||||
// CHECK: tcp.global
|
||||
// CHECK-NOT: tcp.global
|
||||
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
|
||||
%0 = constant dense<7.0> : tensor<3x4xf32>
|
||||
%1 = constant dense<8.0> : tensor<3x4xf32>
|
||||
// CHECK: return %[[TENSOR]]
|
||||
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
|
||||
}
|
||||
|
||||
// CHECK: }
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: module {
|
||||
// We don't convert non-tensor globals.
|
||||
// CHECK-NOT: tcp.global
|
||||
func @non_tensor() {
|
||||
%0 = constant 7 : i32
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK: }
|
|
@ -0,0 +1,32 @@
|
|||
// RUN: npcomp-opt -e2e-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
|
||||
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm<"[3 x float]">
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm<"[1 x i32]">
|
||||
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm<"{ i32, i32*, i8* }"> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm<"{ i32, i32*, i8* }">
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
|
||||
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm<"{ i32, i32*, i8* }">
|
||||
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__npcomprt_global_extents_g : !llvm<"[1 x i32]*">
|
||||
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm<"[1 x i32]*"> to !llvm<"i32*">
|
||||
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm<"{ i32, i32*, i8* }">
|
||||
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__npcomprt_global_data_buffer_g : !llvm<"[3 x float]*">
|
||||
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm<"[3 x float]*"> to !llvm<"i8*">
|
||||
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm<"{ i32, i32*, i8* }">
|
||||
// CHECK: llvm.return %[[VAL_8]] : !llvm<"{ i32, i32*, i8* }">
|
||||
// CHECK: }
|
||||
// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm<"{ i64, i8* }"> {
|
||||
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm<"{ i32, i32*, i8* }*">
|
||||
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm<"{ i32, i32*, i8* }*">) -> !llvm<"{ i64, i8* }">
|
||||
npcomprt.global @g dense<7.000000e+00> : tensor<3xf32>
|
||||
func @calls_get_global() -> memref<*xf32> {
|
||||
%0 = npcomprt.get_global @g : memref<*xf32>
|
||||
return %0 : memref<*xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm<"[1 x float]">
|
||||
// CHECK: llvm.mlir.global internal constant @__npcomprt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm<"[1 x i32]">
|
||||
npcomprt.global @g dense<7.0> : tensor<f32>
|
||||
|
|
@ -37,6 +37,23 @@ func @basic(%arg0: tensor<?xf32>) -> tensor<?xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
|
||||
// CHECK: npcomprt.global @g dense<7.000000e+00> : tensor<10xf32>
|
||||
tcp.global @g dense<7.0> : tensor<10xf32>
|
||||
// CHECK-LABEL: func @gets_global() -> !npcomprt.tensor
|
||||
func @gets_global() -> tensor<10xf32> {
|
||||
// CHECK: %[[GMEMREF:.*]] = npcomprt.get_global @g : memref<*xf32>
|
||||
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
|
||||
// CHECK: %[[RETMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
|
||||
// CHECK: %[[RET:.*]] = npcomprt.from_memref %[[RETMEMREF]] : memref<*xf32>
|
||||
// CHECK: return %[[RET]] : !npcomprt.tensor
|
||||
%0 = tcp.get_global_memref @g : memref<10xf32>
|
||||
%1 = tensor_load %0 : memref<10xf32>
|
||||
return %1 : tensor<10xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{func not expressible with npcomprt ABI}}
|
||||
func @unhandled_abi_type_on_public_func(%arg0: i32) {
|
||||
return
|
||||
|
|
|
@ -0,0 +1,12 @@
|
|||
// RUN: npcomp-run-mlir -input %s \
|
||||
// RUN: -invoke constant_add_scalar \
|
||||
// RUN: -arg-value="dense<3.0> : tensor<f32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: dense<4.000000e+00> : tensor<f32>
|
||||
func @constant_add_scalar(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = constant dense<1.0> : tensor<f32>
|
||||
%1 = "tcf.add"(%arg0, %0) : (tensor<f32>, tensor<f32>) -> tensor<f32>
|
||||
return %1 : tensor<f32>
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
// RUN: npcomp-run-mlir -input %s \
|
||||
// RUN: -invoke constant_add \
|
||||
// RUN: -arg-value="dense<[3.0, 5.0]> : tensor<2xf32>" \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: dense<[4.000000e+00, 7.000000e+00]> : tensor<2xf32>
|
||||
func @constant_add(%arg0: tensor<2xf32>) -> tensor<2xf32> {
|
||||
%0 = constant dense<[1.0, 2.0]> : tensor<2xf32>
|
||||
%1 = "tcf.add"(%arg0, %0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
|
||||
return %1 : tensor<2xf32>
|
||||
}
|
|
@ -0,0 +1,10 @@
|
|||
// RUN: npcomp-run-mlir -input %s \
|
||||
// RUN: -invoke constant \
|
||||
// RUN: -shared-libs=%npcomp_runtime_shlib 2>&1 \
|
||||
// RUN: | FileCheck %s
|
||||
|
||||
// CHECK: output #0: dense<1.000000e+00> : tensor<f32>
|
||||
func @constant() -> tensor<f32> {
|
||||
%0 = constant dense<1.0> : tensor<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
|
@ -47,3 +47,5 @@ npctall() {
|
|||
# https://superuser.com/q/253068
|
||||
export FIGNORE=$FIGNORE:nstall-mlir
|
||||
|
||||
export PYTHONPATH="$(realpath ${build_dir}/python):$(realpath ${build_dir}/python_native):$(realpath ${build_dir}/iree/bindings/python)"
|
||||
|
||||
|
|
Loading…
Reference in New Issue