[RefBackend] Use std.global_memref instead of homegrown thing

This vastly simplifies our code, allowing deleting multiple ops,
simplifying multiple passes, and removing a whole pass.

Now `refback` dialect is down to one op (refback.alloc_memref, which
simplifies allocations to just take a shape instead of individual
extents).
pull/119/head
Sean Silva 2020-11-10 15:14:02 -08:00
parent 6850295ec5
commit 5227d52c26
21 changed files with 5 additions and 734 deletions

View File

@ -20,21 +20,6 @@ class Refback_Op<string mnemonic, list<OpTrait> traits = []>
: Op<Refback_Dialect, mnemonic, traits> {
}
def Refback_GlobalOp : Refback_Op<"global", [Symbol]> {
let summary = "Represents a global variable";
let description = [{
Represents a global variable.
Currently, only constant tensors are supported, and they are not
considered to be exported.
}];
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
let results = (outs);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
//===----------------------------------------------------------------------===//
// Ops related to bufferization.
//===----------------------------------------------------------------------===//
@ -52,15 +37,4 @@ def Refback_AllocMemRefOp : Refback_Op<"alloc_memref", []> {
let assemblyFormat = "$shape attr-dict `:` type($memref)";
}
def Refback_GetGlobalMemrefOp : Refback_Op<"get_global_memref"> {
let summary = "Obtain a memref pointing at the given global";
let description = [{
Obtain a memref pointing at the given global.
}];
let arguments = (ins FlatSymbolRefAttr:$global);
let results = (outs AnyMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
}
#endif // REFBACK_OPS

View File

@ -46,43 +46,6 @@ def Refbackrt_AbortIfOp : Refbackrt_Op<"abort_if"> {
let assemblyFormat = "$pred `,` $msg attr-dict";
}
def Refbackrt_GlobalOp : Refbackrt_Op<"global", [Symbol]> {
let summary = "Represents a global variable";
let description = [{
Represents a global variable.
Currently, only constant tensors are supported, and they are not
considered to be exported.
}];
let arguments = (ins StrAttr:$sym_name, ElementsAttr:$value);
let results = (outs);
let printer = [{ return ::print$cppClass(p, *this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
def Refbackrt_GetGlobalOp : Refbackrt_Op<"get_global"> {
let summary = "Obtain a rank-erased memref pointing at the given global";
let description = [{
Obtain a rank-erased memref pointing at the given global.
TODO: As we define the runtime layer better, we should have fewer
entry points that return memrefs, or at least have a clearer separation
between the "memref world" and the "refbackrt world".
Something like forming IREE dispatch regions seems to be the missing thing:
- Everything inside the dispatch regions gets things marshaled from the
runtime (flow/hal/refbackrt) layer to/from memrefs in a clear way.
- Everything outside the dispatch regions purely uses the runtime
(flow/hal/refbackrt) data structures.
Globals should be one of the things that are purely runtime data structures,
rather than using memrefs. For now, using memrefs is simpler though.
}];
let arguments = (ins FlatSymbolRefAttr:$global);
let results = (outs AnyUnrankedMemRef:$memref);
let assemblyFormat = "$global attr-dict `:` type($memref)";
let verifier = "return ::verify$cppClass(*this);";
}
def Refbackrt_ModuleMetadataOp : Refbackrt_Op<"module_metadata", [
SingleBlockImplicitTerminator<"ModuleMetadataTerminatorOp">
]> {

View File

@ -11,15 +11,6 @@
include "mlir/Pass/PassBase.td"
def LowerConstantTensorsToMemref :
Pass<"lower-constant-tensors-to-memref", "ModuleOp"> {
let summary = "Lower std.constant of tensor type to memref";
let description = [{
This must be a module pass since it involves creating refback.global ops.
}];
let constructor = "mlir::NPCOMP::createLowerConstantTensorsToMemrefPass()";
}
def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
let summary = "Lower constructs requiring runtime support to `refbackrt`";
let description = [{
@ -30,7 +21,6 @@ def LowerToRefbackrtABI : Pass<"lower-to-refbackrt-abi", "ModuleOp"> {
The constructs requiring runtime support are:
- function signatures / module metadata
- globals
- error handling
}];
let constructor = "mlir::NPCOMP::createLowerToRefbackrtABIPass()";

View File

@ -23,9 +23,6 @@ void registerRefBackendPasses();
//
// Pass summaries are in Passes.td.
std::unique_ptr<OperationPass<ModuleOp>>
createLowerConstantTensorsToMemrefPass();
std::unique_ptr<OperationPass<FuncOp>> createLowerStructuralToMemrefPass();
std::unique_ptr<OperationPass<ModuleOp>> createLowerToRefbackrtABIPass();

View File

@ -15,49 +15,5 @@ using namespace mlir;
using namespace mlir::NPCOMP;
using namespace mlir::NPCOMP::refback;
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
p << "refback.global ";
p.printSymbolName(op.sym_name());
p << ' ';
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
/*elidedAttrs=*/{"sym_name", "value"});
p.printAttribute(op.valueAttr());
}
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
Attribute valueAttr;
if (parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalMemrefOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyGetGlobalMemrefOp(GetGlobalMemrefOp op) {
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
if (!global)
return op.emitError() << "must reference a valid symbol";
auto globalType = global.value().getType();
auto resultType = op.getType().cast<ShapedType>();
if (failed(
verifyCompatibleShape(globalType.getShape(), resultType.getShape())))
return op.emitError() << "inconsistent with shape of global";
if (globalType.getElementType() != resultType.getElementType())
return op.emitError() << "inconsistent with element type of global";
return success();
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/Refback/IR/RefbackOps.cpp.inc"

View File

@ -16,47 +16,6 @@
using namespace mlir;
using namespace mlir::NPCOMP::refbackrt;
//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//
static void printGlobalOp(OpAsmPrinter &p, GlobalOp &op) {
p << "refbackrt.global ";
p.printSymbolName(op.sym_name());
p << ' ';
p.printOptionalAttrDictWithKeyword(op.getAttrs(),
/*elidedAttrs=*/{"sym_name", "value"});
p.printAttribute(op.valueAttr());
}
static ParseResult parseGlobalOp(OpAsmParser &parser, OperationState &result) {
StringAttr nameAttr;
if (parser.parseSymbolName(nameAttr, mlir::SymbolTable::getSymbolAttrName(),
result.attributes))
return failure();
if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
return failure();
Attribute valueAttr;
if (parser.parseAttribute(valueAttr, "value", result.attributes))
return failure();
return success();
}
//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//
static LogicalResult verifyGetGlobalOp(GetGlobalOp op) {
auto global = SymbolTable::lookupNearestSymbolFrom<GlobalOp>(op, op.global());
if (!global)
return op.emitError() << "must reference a valid refbackrt.global";
auto globalType = global.value().getType();
auto resultType = op.getType().cast<ShapedType>();
if (globalType.getElementType() != resultType.getElementType())
return op.emitError() << "inconsistent with element type of global";
return success();
}
//===----------------------------------------------------------------------===//
// ModuleMetadataOp
//===----------------------------------------------------------------------===//

View File

@ -5,7 +5,6 @@ add_npcomp_library(NPCOMPRefBackend
RefBackend.cpp
LowerToLLVM.cpp
LowerToRefbackrtABI.cpp
TensorToMemref/LowerConstantTensorsToMemref.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SRC_DIR}/include/npcomp/RefBackend

View File

@ -56,22 +56,6 @@ static LLVMType getModuleDescriptorTy(MLIRContext *context) {
});
}
// Get the LLVMType for refbackrt::GlobalDescriptor.
static LLVMType getGlobalDescriptorTy(MLIRContext *context) {
return LLVMType::getStructTy(
// std::int32_t numExtents;
LLVMType::getIntNTy(context, 32),
// std::int32_t *extents;
LLVMType::getIntNTy(context, 32).getPointerTo(),
// It is important that this struct member is a type-erased pointer
// so that this type is "context-free" and can be created in conversion
// patterns independently of the actual type of the data stored in the
// buffer.
//
// void *data;
LLVMType::getInt8PtrTy(context));
}
//===----------------------------------------------------------------------===//
// Compiler runtime functions.
//===----------------------------------------------------------------------===//
@ -122,35 +106,6 @@ public:
};
} // namespace
namespace {
class GetGlobalOpCompilerRuntimeLowering
: public OpConversionPattern<refbackrt::GetGlobalOp> {
public:
GetGlobalOpCompilerRuntimeLowering(LLVM::LLVMFuncOp backingFunc)
: OpConversionPattern<refbackrt::GetGlobalOp>(backingFunc.getContext()),
backingFunc(backingFunc) {}
LogicalResult
matchAndRewrite(refbackrt::GetGlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// It would be nice if we could use the constructor here that takes just the
// global, but keeping track of the converted llvm.mlir.global op that gets
// created from the refbackrt.global while conversion is going on is a
// headache.
//
// Instead, we rely on the symbol name being the same and the result type
// always being the same.
auto globalAddr = rewriter.create<LLVM::AddressOfOp>(
op.getLoc(),
getGlobalDescriptorTy(rewriter.getContext()).getPointerTo(),
op.globalAttr());
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, backingFunc,
ValueRange({globalAddr}));
return success();
}
LLVM::LLVMFuncOp backingFunc;
};
} // namespace
static LLVM::GlobalOp createGlobalString(ModuleOp module, StringAttr msg,
OpBuilder &builder, Location loc) {
// TODO: Deduplicate strings.
@ -262,164 +217,8 @@ static void populateCompilerRuntimePatterns(ModuleOp module,
"from_memref", funcTy, builder, module.getLoc());
patterns.insert<FromMemrefOpCompilerRuntimeLowering>(fromMemrefFunc);
}
{
// Hardcoding f32 is fine here, since unranked memref descriptors have
// identical struct layout / ABI / contents regardless of the element type.
auto mlirFunctionType = builder.getFunctionType(
{getGlobalDescriptorTy(context).getPointerTo()},
{UnrankedMemRefType::get(builder.getF32Type(), /*memorySpace=*/0)});
LLVMType funcTy = convertFunctionType(mlirFunctionType);
LLVMFuncOp backingFunc = createCompilerRuntimeFuncDecl(
"get_global", funcTy, builder, module.getLoc());
patterns.insert<GetGlobalOpCompilerRuntimeLowering>(backingFunc);
}
}
//===----------------------------------------------------------------------===//
// Lowering for refbackrt.global
//===----------------------------------------------------------------------===//
namespace {
class LowerRefbackrtGlobalOp : public OpConversionPattern<refbackrt::GlobalOp> {
public:
explicit LowerRefbackrtGlobalOp(LLVMTypeConverter &typeConverter)
: OpConversionPattern<refbackrt::GlobalOp>(&typeConverter.getContext()),
typeConverter(typeConverter) {}
LogicalResult
matchAndRewrite(refbackrt::GlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto *context = rewriter.getContext();
auto globalDescriptorTy = getGlobalDescriptorTy(context);
// Create the data buffer.
auto dataBuffer = createGlobalForDenseElementsAttr(
(Twine("__refbackrt_global_data_buffer_") + op.sym_name()).str(),
op.value().cast<DenseElementsAttr>(), op, rewriter);
// Create the extents buffer.
auto extentsI32 = rewriter.getI32TensorAttr(llvm::to_vector<6>(
llvm::map_range(op.value().getType().cast<ShapedType>().getShape(),
[](int64_t i) -> int32_t { return i; })));
auto extentsBuffer = createGlobalForDenseElementsAttr(
(Twine("__refbackrt_global_extents_") + op.sym_name()).str(),
extentsI32, op, rewriter);
// Create the GlobalDescriptor.
auto globalDescriptorGlobal = rewriter.create<LLVM::GlobalOp>(
op.getLoc(), globalDescriptorTy, /*isConstant=*/true,
LLVM::Linkage::Internal, op.sym_name(), /*value=*/Attribute());
OpBuilder::InsertionGuard guard(rewriter);
rewriter.createBlock(&globalDescriptorGlobal.initializer());
// Create the body of the initializer.
Value globalDescriptor =
rewriter.create<LLVM::UndefOp>(op.getLoc(), globalDescriptorTy);
auto updateDescriptor = [&](Value value,
std::initializer_list<int32_t> position) {
globalDescriptor = rewriter.create<LLVM::InsertValueOp>(
op.getLoc(), globalDescriptor, value,
/*position=*/rewriter.getI32ArrayAttr(position));
};
updateDescriptor(
rewriter.create<LLVM::ConstantOp>(
op.getLoc(), LLVMType::getIntNTy(context, 32),
rewriter.getI32IntegerAttr(
op.value().getType().cast<ShapedType>().getRank())),
{0});
// The global is actually an array, so we need to get a bare i32* pointer
// type. We could do this with GEP but it would be more verbose.
auto extentsBufferArrayAddress =
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), extentsBuffer);
auto extentsBufferAddress = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), LLVMType::getIntNTy(context, 32).getPointerTo(),
extentsBufferArrayAddress);
updateDescriptor(extentsBufferAddress, {1});
auto dataBufferAddress =
rewriter.create<LLVM::AddressOfOp>(op.getLoc(), dataBuffer);
auto typeErasedDataBufferAddress = rewriter.create<LLVM::BitcastOp>(
op.getLoc(), LLVMType::getInt8PtrTy(context), dataBufferAddress);
updateDescriptor(typeErasedDataBufferAddress, {2});
rewriter.create<LLVM::ReturnOp>(op.getLoc(), globalDescriptor);
rewriter.eraseOp(op);
return success();
}
private:
// TODO: It feels like MLIR core should have better utilities for this.
LLVM::GlobalOp createGlobalForDenseElementsAttr(
StringRef symbolName, DenseElementsAttr elements, refbackrt::GlobalOp op,
ConversionPatternRewriter &rewriter) const {
auto type = elements.getType().cast<ShapedType>();
// LLVM translation doesn't handle the case of zero-sized tensors, which can
// happen e.g. for the number of extents of a rank-0 (i.e. scalar).
//
// We fake-up a size-1 DenseElementsAttr to use for creating the global.
// That takes up binary space (one element instead of zero), but that seems
// fine.
//
// TODO: LLVM translation in MLIR core should handle this case better.
if (type.getNumElements() == 0) {
auto elementType = type.getElementType();
Attribute singleElement;
if (elementType.isIntOrIndex())
singleElement = rewriter.getIntegerAttr(elementType, 0);
else if (elementType.isa<FloatType>())
singleElement = rewriter.getFloatAttr(elementType, 0);
assert(singleElement &&
"could not fake up an element for a zero element tensor");
type = RankedTensorType::get({1}, elementType);
elements =
DenseElementsAttr::get(type, ArrayRef<Attribute>(singleElement));
}
auto llvmType = getLLVMTypeForShapedType(type, op, rewriter);
return rewriter.create<LLVM::GlobalOp>(
op.getLoc(), llvmType,
/*isConstant=*/true, LLVM::Linkage::Internal, symbolName, elements);
}
LLVMType getLLVMTypeForShapedType(ShapedType type, refbackrt::GlobalOp op,
ConversionPatternRewriter &rewriter) const {
auto llvmType =
typeConverter.convertType(type.getElementType()).cast<LLVMType>();
// MLIR->LLVM lowering for globals requires non-scalar data types. So use a
// dummy size-1 array for the scalar case.
//
// TODO: LLVM translation in MLIR core should handle this case better.
if (type.getRank() == 0)
return LLVMType::getArrayTy(llvmType, 1);
if (!llvmType) {
rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
diag << "cannot convert element type " << type.getElementType()
<< " to an LLVM type";
});
return nullptr;
}
// Construct an LLVM nested array type for the tensor initializer.
// tensor<f32> -> float
// tensor<10xf32> -> [10 x float]
// tensor<2x3xf32> -> [2 x [3 x float]]
assert(type.hasStaticShape());
auto shape = type.getShape();
while (!shape.empty()) {
llvmType = LLVMType::getArrayTy(llvmType, shape.back());
shape = shape.drop_back();
}
return llvmType;
}
LLVMTypeConverter &typeConverter;
};
} // namespace
//===----------------------------------------------------------------------===//
// Lowering for module metadata
//===----------------------------------------------------------------------===//
@ -706,7 +505,6 @@ class LowerToLLVM : public LowerToLLVMBase<LowerToLLVM> {
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
populateStdToLLVMConversionPatterns(converter, patterns);
patterns.insert<LowerModuleMetadata>(context);
patterns.insert<LowerRefbackrtGlobalOp>(converter);
// TODO: Move these "std to std" legalizations to their own pass if we grow
// lots of these patterns.

View File

@ -74,37 +74,6 @@ static LogicalResult createModuleMetadata(ModuleOp module) {
// Dialect conversion.
//===----------------------------------------------------------------------===//
namespace {
class LowerGlobalOp : public OpConversionPattern<refback::GlobalOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(refback::GlobalOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<refbackrt::GlobalOp>(op, op.sym_name(),
op.value());
return success();
}
};
} // namespace
namespace {
class LowerGetGlobalMemrefOp
: public OpConversionPattern<refback::GetGlobalMemrefOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(refback::GetGlobalMemrefOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto abiMemref = rewriter.create<refbackrt::GetGlobalOp>(
op.getLoc(), getABIMemrefType(op.getType()), op.global());
// Cast back to the original type.
rewriter.replaceOpWithNewOp<MemRefCastOp>(op, abiMemref, op.getType());
return success();
}
};
} // namespace
namespace {
class LowerAssertOp : public OpConversionPattern<AssertOp> {
public:
@ -216,12 +185,6 @@ static LogicalResult doDialectConversion(ModuleOp module) {
target.addDynamicallyLegalOp<ReturnOp>(
[&](ReturnOp op) { return typeConverter.isLegal(op); });
patterns.insert<LowerGlobalOp>(context);
target.addIllegalOp<refback::GlobalOp>();
patterns.insert<LowerGetGlobalMemrefOp>(context);
target.addIllegalOp<refback::GetGlobalMemrefOp>();
patterns.insert<LowerAssertOp>(context);
target.addIllegalOp<AssertOp>();

View File

@ -228,8 +228,7 @@ void mlir::NPCOMP::createRefBackendLoweringPipeline(
// Bufferize the TCP dialect.
pm.addNestedPass<FuncOp>(createTCPBufferizePass());
// Lower tensor-valued constants to refback.global.
pm.addPass(createLowerConstantTensorsToMemrefPass());
pm.addPass(createTensorConstantBufferizePass());
// refback::AllocMemRefOp takes a shape (i.e. extent tensor) as an argument.
// We need to resolve this to std.alloc which takes individual extents.
pm.addNestedPass<FuncOp>(createLowerAllocMemRefOpsPass());

View File

@ -48,13 +48,6 @@ struct ModuleDescriptor {
FuncDescriptor *functionDescriptors;
};
// Static data representing a global variable (together with its shape).
struct GlobalDescriptor {
std::int32_t numExtents;
std::int32_t *extents;
void *data;
};
} // namespace refbackrt
#endif // NPCOMP_LIB_RUNTIME_COMPILERDATASTRUCTURES_H

View File

@ -114,11 +114,3 @@ __npcomp_compiler_rt_from_memref(std::int64_t rank,
return Tensor::createRaw(ArrayRef<std::int32_t>(extents32Buf.data(), rank),
elementType, data);
}
extern "C" UnrankedMemref
__npcomp_compiler_rt_get_global(GlobalDescriptor *global) {
auto *descriptor = MemrefDescriptor::create(
ArrayRef<std::int32_t>(global->extents, global->numExtents),
global->data);
return UnrankedMemref{global->numExtents, descriptor};
}

View File

@ -1,113 +0,0 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "../PassDetail.h"
#include "npcomp/RefBackend/RefBackend.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/DialectConversion.h"
#include "npcomp/Dialect/Refback/IR/RefbackDialect.h"
#include "npcomp/Dialect/Refback/IR/RefbackOps.h"
using namespace mlir;
using namespace mlir::NPCOMP;
//===----------------------------------------------------------------------===//
// LowerConstantTensorsToMemref
//===----------------------------------------------------------------------===//
namespace {
// This class creates global ops for all tensor-valued constants in the program.
// It creates them with pretty names and makes sure that duplicate globals
// aren't created.
class GlobalCreator {
public:
explicit GlobalCreator(ModuleOp module);
refback::GlobalOp getGlobalFor(Attribute attr) {
assert(globals.find(attr) != globals.end() && "unknown constant attr");
return globals[attr];
}
private:
DenseMap<Attribute, refback::GlobalOp> globals;
};
GlobalCreator::GlobalCreator(ModuleOp module) {
// Create a builder without an insertion point. We will insert using the
// symbol table to guarantee unique names.
OpBuilder globalBuilder(module.getContext());
SymbolTable symbolTable(module);
module.walk([&](ConstantOp op) {
// We only want tensor constants for now.
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type)
return;
// If we already have a global for this constant value, no need to do
// anything else.
auto it = globals.find(op.getValue());
if (it != globals.end())
return;
// Create a pretty name.
SmallString<64> buf;
llvm::raw_svector_ostream os(buf);
interleave(type.getShape(), os, "x");
os << "x" << type.getElementType();
auto global = globalBuilder.create<refback::GlobalOp>(
op.getLoc(), (Twine("__constant_") + os.str()).str(),
op.getValue().cast<ElementsAttr>());
symbolTable.insert(global);
// The symbol table inserts at the end of the module, but globals are a bit
// nicer if they are at the beginning.
global.getOperation()->moveBefore(&module.front());
globals[op.getValue()] = global;
});
}
} // namespace
namespace {
class LowerConstantTensorsToMemref
: public LowerConstantTensorsToMemrefBase<LowerConstantTensorsToMemref> {
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<refback::RefbackDialect>();
}
void runOnOperation() override {
auto module = getOperation();
GlobalCreator globals(module);
// With the global traversal factored into GlobalCreator, this could in
// principle be done with a pattern.
module.walk([&](ConstantOp op) {
auto type = op.getType().dyn_cast<RankedTensorType>();
if (!type)
return;
auto global = globals.getGlobalFor(op.getValue());
OpBuilder builder(op);
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
auto memref = builder.create<refback::GetGlobalMemrefOp>(
op.getLoc(), memrefType, global.getName());
Value tensor = builder.create<TensorLoadOp>(op.getLoc(), type, memref);
op.replaceAllUsesWith(tensor);
op.erase();
});
}
};
} // namespace
std::unique_ptr<OperationPass<ModuleOp>>
mlir::NPCOMP::createLowerConstantTensorsToMemrefPass() {
return std::make_unique<LowerConstantTensorsToMemref>();
}

View File

@ -1,31 +0,0 @@
// RUN: npcomp-opt -split-input-file -verify-diagnostics <%s
// -----
refback.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid symbol}}
refback.get_global_memref @nonexistent_symbol : memref<3xf32>
return
}
// -----
refback.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with shape of global}}
refback.get_global_memref @g : memref<3xf32>
return
}
// -----
refback.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
refback.get_global_memref @g : memref<2xi8>
return
}

View File

@ -1,11 +1,8 @@
// RUN: npcomp-opt <%s | npcomp-opt | FileCheck %s
// CHECK-LABEL: refback.global @foo dense<0.0{{.*}}> : tensor<10xf32>
refback.global @foo dense<0.0> : tensor<10xf32>
// CHECK-LABEL: func @global
func @global() {
// CHECK: refback.get_global_memref @foo : memref<10xf32>
%0 = refback.get_global_memref @foo : memref<10xf32>
// CHECK-LABEL: @alloc_memref
func @alloc_memref(%arg0: tensor<?xindex>) {
// CHECK: refback.alloc_memref
%0 = refback.alloc_memref %arg0 : memref<?xf32>
return
}

View File

@ -21,23 +21,3 @@ refbackrt.module_metadata {
refbackrt.func_metadata {funcName = @f, numInputs = 0 : i32, numOutputs = 1 : i32}
}
func @f() { return }
// -----
refbackrt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{must reference a valid refbackrt.global}}
refbackrt.get_global @nonexistent_symbol : memref<*xf32>
return
}
// -----
refbackrt.global @g dense<0.0> : tensor<2xf32>
func @f() {
// expected-error @+1 {{inconsistent with element type of global}}
refbackrt.get_global @g : memref<*xi8>
return
}

View File

@ -11,10 +11,3 @@ refbackrt.module_metadata {
func @f(%arg0: !refbackrt.tensor) {
return
}
// CHECK-LABEL: refbackrt.global @g dense<0.0{{.*}}> : tensor<10xf32>
refbackrt.global @g dense<0.0> : tensor<10xf32>
func @uses_global() {
refbackrt.get_global @g : memref<*xf32>
return
}

View File

@ -1,61 +0,0 @@
// RUN: npcomp-opt -split-input-file -lower-constant-tensors-to-memref <%s | FileCheck %s
// CHECK-LABEL: module {
// We check the debug name too since we put some effort into making that readable.
// The name isn't load-bearing though.
// CHECK: refback.global @__constant_3x4xf32 dense<7.000000e+00> : tensor<3x4xf32>
// CHECK: func @basic
func @basic() -> tensor<3x4xf32> {
// CHECK: %[[MEMREF:.*]] = refback.get_global_memref @__constant_3x4xf32 : memref<3x4xf32>
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF]]
%0 = constant dense<7.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0 : tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// Only one global is created.
// CHECK: refback.global
// CHECK-NOT: refback.global
func @duplicate_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<7.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// Two globals are created.
// CHECK: refback.global
// CHECK: refback.global
// CHECK-NOT: refback.global
func @multiple_constants() -> (tensor<3x4xf32>, tensor<3x4xf32>) {
%0 = constant dense<7.0> : tensor<3x4xf32>
%1 = constant dense<8.0> : tensor<3x4xf32>
// CHECK: return %[[TENSOR]]
return %0, %1 : tensor<3x4xf32>, tensor<3x4xf32>
}
// CHECK: }
// -----
// CHECK-LABEL: module {
// We don't convert non-tensor globals.
// CHECK-NOT: refback.global
func @non_tensor() {
%0 = constant 7 : i32
return
}
// CHECK: }

View File

@ -1,58 +0,0 @@
// RUN: npcomp-opt -refback-lower-to-llvm -split-input-file <%s | FileCheck %s --dump-input=fail
// CHECK-LABEL: llvm.func @malloc(!llvm.i64) -> !llvm.ptr<i8>
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<3xf32>) : !llvm.array<3 x float>
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<3> : tensor<1xi32>) : !llvm.array<1 x i32>
// CHECK-LABEL: llvm.mlir.global internal constant @g() : !llvm.struct<(i32, ptr<i32>, ptr<i8>)> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.undef : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: %[[VAL_1:.*]] = llvm.mlir.constant(1 : i32) : !llvm.i32
// CHECK: %[[VAL_2:.*]] = llvm.insertvalue %[[VAL_1]], %[[VAL_0]][0 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: %[[VAL_3:.*]] = llvm.mlir.addressof @__refbackrt_global_extents_g : !llvm.ptr<array<1 x i32>>
// CHECK: %[[VAL_4:.*]] = llvm.bitcast %[[VAL_3]] : !llvm.ptr<array<1 x i32>> to !llvm.ptr<i32>
// CHECK: %[[VAL_5:.*]] = llvm.insertvalue %[[VAL_4]], %[[VAL_2]][1 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: %[[VAL_6:.*]] = llvm.mlir.addressof @__refbackrt_global_data_buffer_g : !llvm.ptr<array<3 x float>>
// CHECK: %[[VAL_7:.*]] = llvm.bitcast %[[VAL_6]] : !llvm.ptr<array<3 x float>> to !llvm.ptr<i8>
// CHECK: %[[VAL_8:.*]] = llvm.insertvalue %[[VAL_7]], %[[VAL_5]][2 : i32] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: llvm.return %[[VAL_8]] : !llvm.struct<(i32, ptr<i32>, ptr<i8>)>
// CHECK: }
// CHECK-LABEL: llvm.func @calls_get_global() -> !llvm.struct<(i64, ptr<i8>)> {
// CHECK: %[[VAL_0:.*]] = llvm.mlir.addressof @g : !llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>
// CHECK: %[[VAL_1:.*]] = llvm.call @__npcomp_compiler_rt_get_global(%[[VAL_0]]) : (!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_2:.*]] = llvm.mlir.constant(1 : index) : !llvm.i64
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// CHECK: %[[VAL_5:.*]] = llvm.mlir.constant(8 : index) : !llvm.i64
// CHECK: %[[VAL_6:.*]] = llvm.mul %[[VAL_3]], %[[VAL_4]] : !llvm.i64
// CHECK: %[[VAL_7:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_8:.*]] = llvm.mul %[[VAL_3]], %[[VAL_7]] : !llvm.i64
// CHECK: %[[VAL_9:.*]] = llvm.add %[[VAL_8]], %[[VAL_2]] : !llvm.i64
// CHECK: %[[VAL_10:.*]] = llvm.mul %[[VAL_9]], %[[VAL_5]] : !llvm.i64
// CHECK: %[[VAL_11:.*]] = llvm.add %[[VAL_6]], %[[VAL_10]] : !llvm.i64
// CHECK: %[[VAL_12:.*]] = llvm.mlir.constant(false) : !llvm.i1
// CHECK: %[[VAL_13:.*]] = llvm.call @malloc(%[[VAL_11]]) : (!llvm.i64) -> !llvm.ptr<i8>
// CHECK: %[[VAL_14:.*]] = llvm.extractvalue %[[VAL_1]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: "llvm.intr.memcpy"(%[[VAL_13]], %[[VAL_14]], %[[VAL_11]], %[[VAL_12]]) : (!llvm.ptr<i8>, !llvm.ptr<i8>, !llvm.i64, !llvm.i1) -> ()
// CHECK: %[[VAL_15:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_16:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_17:.*]] = llvm.insertvalue %[[VAL_16]], %[[VAL_15]][0] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: %[[VAL_18:.*]] = llvm.insertvalue %[[VAL_13]], %[[VAL_17]][1] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.return %[[VAL_18]] : !llvm.struct<(i64, ptr<i8>)>
// CHECK: }
refbackrt.global @g dense<7.000000e+00> : tensor<3xf32>
func @calls_get_global() -> memref<*xf32> {
%0 = refbackrt.get_global @g : memref<*xf32>
return %0 : memref<*xf32>
}
// -----
// For scalars, we have to fake-up a size-1 data buffer array to make LLVM translation happy.
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_data_buffer_g(dense<7.000000e+00> : tensor<f32>) : !llvm.array<1 x float>
// CHECK: llvm.mlir.global internal constant @__refbackrt_global_extents_g(dense<0> : tensor<1xi32>) : !llvm.array<1 x i32>
refbackrt.global @g dense<7.0> : tensor<f32>

View File

@ -17,7 +17,6 @@
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_identity("identity")
// CHECK-LABEL: llvm.mlir.global internal constant @__npcomp_func_descriptors() : !llvm.array<1 x struct<(i32, ptr<i8>, ptr<i8>, i32, i32)>> {
@ -114,7 +113,6 @@ func @identity(%arg0: !refbackrt.tensor) -> !refbackrt.tensor {
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results0("inputs1results0")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results1("inputs1results1")
// CHECK: llvm.mlir.global internal constant @__npcomp_internal_constant_inputs1results2("inputs1results2")
@ -215,7 +213,6 @@ func @inputs1results2(%arg0: !refbackrt.tensor) -> (!refbackrt.tensor, !refbackr
// CHECK: llvm.func @__npcomp_compiler_rt_abort_if(!llvm.i1, !llvm.ptr<i8>)
// CHECK: llvm.func @__npcomp_compiler_rt_to_memref(!llvm.ptr<i8>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK: llvm.func @__npcomp_compiler_rt_from_memref(!llvm.i64, !llvm.ptr<i8>) -> !llvm.ptr<i8>
// CHECK: llvm.func @__npcomp_compiler_rt_get_global(!llvm.ptr<struct<(i32, ptr<i32>, ptr<i8>)>>) -> !llvm.struct<(i64, ptr<i8>)>
// CHECK-LABEL: llvm.func @calls_abort_if(
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.i1) {

View File

@ -73,22 +73,6 @@ func @multiple_blocks(%arg0: memref<?xf32>) -> memref<?xf32> {
// -----
// CHECK: refbackrt.global @g dense<7.000000e+00> : tensor<10xf32>
refback.global @g dense<7.0> : tensor<10xf32>
// CHECK-LABEL: func @gets_global() -> !refbackrt.tensor
func @gets_global() -> memref<10xf32> {
// CHECK: %[[GMEMREF:.*]] = refbackrt.get_global @g : memref<*xf32>
// CHECK: %[[ORIGMEMREF:.*]] = memref_cast %[[GMEMREF]] : memref<*xf32> to memref<10xf32>
// CHECK: %[[OUTABIMEMREF:.*]] = memref_cast %[[ORIGMEMREF:.*]] : memref<10xf32> to memref<*xf32>
// CHECK: %[[RET:.*]] = refbackrt.from_memref %[[OUTABIMEMREF]] : memref<*xf32>
// CHECK: return %[[RET]] : !refbackrt.tensor
%0 = refback.get_global_memref @g : memref<10xf32>
return %0 : memref<10xf32>
}
// -----
// Test diagnostics.
// expected-error @+1 {{func not expressible with refbackrt ABI}}